def wrapper(obj, *args, **kwargs):
  """Calls the wrapped `method` inside `obj`'s variable scope with reuse on.

  The variable scope is obtained from `obj._enter_variable_scope` if present,
  otherwise from `tf.variable_scope(obj.variable_scope)`. It is re-entered with
  `reuse=tf.AUTO_REUSE`. Then `method` is run under a name scope named after
  the snake-cased method name, nested beneath the scope's original name scope.

  Args:
    obj: The object whose method is wrapped. It must expose either
      `_enter_variable_scope` or a `variable_scope` attribute.
    *args: Positional arguments forwarded to `method`.
    **kwargs: Keyword arguments forwarded to `method`.

  Returns:
    Whatever `method(obj, *args, **kwargs)` returns.
  """
  def default_context_manager(reuse=None):
    variable_scope = obj.variable_scope
    return tf.variable_scope(variable_scope, reuse=reuse)

  variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                           default_context_manager)

  # Enter the (possibly custom) context manager only to capture the scope
  # object. It is exited immediately and re-entered below with AUTO_REUSE.
  with variable_scope_context_manager() as tmp_variable_scope:
    variable_scope = tmp_variable_scope

  with variable_scope_ops._pure_variable_scope(  # pylint: disable=protected-access
      variable_scope, reuse=tf.AUTO_REUSE):
    name_scope = variable_scope.original_name_scope
    # A trailing "/" makes tf.name_scope re-enter the existing scope instead
    # of creating a new, uniquified one. An empty scope is left as-is rather
    # than indexed, which previously raised IndexError.
    if name_scope and not name_scope.endswith("/"):
      name_scope += "/"
    with tf.name_scope(name_scope):
      sub_scope = snt_util.to_snake_case(method.__name__)
      with tf.name_scope(sub_scope):
        return method(obj, *args, **kwargs)
def testToSnakeCase(self):
  """Checks `util.to_snake_case` against a table of known conversions."""
  test_cases = [
      ("UpperCamelCase", "upper_camel_case"),
      ("lowerCamelCase", "lower_camel_case"),
      ("endsWithXYZ", "ends_with_xyz"),
      ("already_snake_case", "already_snake_case"),
      ("__private__", "private"),
      ("LSTMModule", "lstm_module"),
      ("version123p56vfxObject", "version_123p56vfx_object"),
      ("version123P56VFXObject", "version_123p56vfx_object"),
      ("versionVFX123P56Object", "version_vfx123p56_object"),
      ("versionVfx123P56Object", "version_vfx_123p56_object"),
      ("lstm1", "lstm_1"),
      ("LSTM1", "lstm1"),
  ]
  for camel_case, snake_case in test_cases:
    actual = util.to_snake_case(camel_case)
    # The failure message names the function actually under test.
    self.assertEqual(
        actual, snake_case,
        "to_snake_case(%s) -> %s != %s" % (camel_case, actual, snake_case))
def __init__(self, _sentinel=None, custom_getter=None, name=None):  # pylint: disable=invalid-name
  """Performs the initialisation necessary for all AbstractModule instances.

  Every subclass of AbstractModule must begin their constructor with a call to
  this constructor, i.e.

  `super(MySubModule, self).__init__(custom_getter=custom_getter, name=name)`.

  If you instantiate sub-modules in __init__ you must create them within the
  `_enter_variable_scope` context manager to ensure they are in the module's
  variable scope. Alternatively, instantiate sub-modules in `_build`.

  Args:
    _sentinel: Variable that only carries a non-None value if `__init__` was
        called without named parameters. If this is the case, a `ValueError`
        is raised.
    custom_getter: Callable or dictionary of callables to use as
        custom getters inside the module. If a dictionary, the keys
        correspond to regexes to match variable names. See the
        `tf.get_variable` documentation for information about the
        custom_getter API.
    name: Name of this module. Used to construct the Templated build function.
        If `None` the module's class name is used (converted to snake case).

  Raises:
    TypeError: If `name` is not a string.
    TypeError: If a given `custom_getter` is not callable.
    ValueError: If `__init__` was called without named arguments.
  """
  # `collections.Mapping` was removed in Python 3.10. Prefer the ABC from
  # `collections.abc` and fall back for Python 2, which lacks that module.
  try:
    from collections.abc import Mapping  # pylint: disable=g-importing-member
  except ImportError:
    from collections import Mapping  # pylint: disable=g-importing-member

  if _sentinel is not None:
    raise ValueError("Calling AbstractModule.__init__ without named "
                     "arguments is not supported.")

  if name is None:
    name = util.to_snake_case(self.__class__.__name__)
  elif not isinstance(name, six.string_types):
    raise TypeError("Name must be a string, not {} of type {}.".format(
        name, type(name)))

  self._is_connected = False
  self._connected_subgraphs = []

  # If the given custom getter is a dictionary with a per-variable custom
  # getter, wrap it into a single custom getter.
  if isinstance(custom_getter, Mapping):
    self._custom_getter = util.custom_getter_router(
        custom_getter_map=custom_getter,
        # Drops the first len(scope_name) + 1 characters (the scope name plus
        # the separator). `self.scope_name` is read when the lambda is called,
        # not now. Presumably this lets the regex keys match names relative to
        # this module.
        name_fn=lambda name: name[len(self.scope_name) + 1:])
  elif custom_getter is not None and not callable(custom_getter):
    raise TypeError("Given custom_getter is not callable.")
  else:
    self._custom_getter = custom_getter

  self._template = tf.make_template(name,
                                    self._build_wrapper,
                                    create_scope_now_=True,
                                    custom_getter_=self._custom_getter)

  self._original_name = name
  # The last component of the template's variable scope name.
  self._unique_name = self._template.variable_scope.name.split("/")[-1]

  # Update __call__ and the object docstrings to enable better introspection.
  self.__doc__ = self._build.__doc__
  # NOTE(review): `__func__` is the function object shared by the defining
  # class, so this assignment changes the `__call__` docstring class-wide,
  # not just for this instance.
  self.__call__.__func__.__doc__ = self._build.__doc__

  # Keep track of which graph this module has been connected to. Sonnet
  # modules cannot be connected to multiple graphs, as transparent variable
  # sharing is impossible in that case.
  self._graph = None

  # Container for all variables created in this module and its sub-modules.
  self._all_variables = set()

  # Calling `.defun()` causes the module's call method to become wrapped as
  # a graph function.
  self._defun_wrapped = False
def __init__(self, _sentinel=None, custom_getter=None, name=None):  # pylint: disable=invalid-name
  """Performs the initialisation necessary for all AbstractModule instances.

  Every subclass of AbstractModule must begin their constructor with a call to
  this constructor, i.e.

  `super(MySubModule, self).__init__(custom_getter=custom_getter, name=name)`.

  If you instantiate sub-modules in __init__ you must create them within the
  `_enter_variable_scope` context manager to ensure they are in the module's
  variable scope. Alternatively, instantiate sub-modules in `_build`.

  Args:
    _sentinel: Variable that only carries a non-None value if `__init__` was
        called without named parameters. If this is the case, a `ValueError`
        is raised.
    custom_getter: Callable or dictionary of callables to use as
        custom getters inside the module. If a dictionary, the keys
        correspond to regexes to match variable names. See the
        `tf.get_variable` documentation for information about the
        custom_getter API.
    name: Name of this module. Used to construct the Templated build function.
        If `None` the module's class name is used (converted to snake case).

  Raises:
    TypeError: If `name` is not a string.
    TypeError: If a given `custom_getter` is not callable.
    ValueError: If `__init__` was called without named arguments.
  """
  # `collections.Mapping` was removed in Python 3.10. Prefer the ABC from
  # `collections.abc` and fall back for Python 2, which lacks that module.
  try:
    from collections.abc import Mapping  # pylint: disable=g-importing-member
  except ImportError:
    from collections import Mapping  # pylint: disable=g-importing-member

  if _sentinel is not None:
    raise ValueError("Calling AbstractModule.__init__ without named "
                     "arguments is not supported.")

  if name is None:
    name = util.to_snake_case(self.__class__.__name__)
  elif not isinstance(name, six.string_types):
    raise TypeError("Name must be a string.")

  self._connected_subgraphs = []

  # If the given custom getter is a dictionary with a per-variable custom
  # getter, wrap it into a single custom getter.
  if isinstance(custom_getter, Mapping):
    self._custom_getter = util.custom_getter_router(
        custom_getter_map=custom_getter,
        # Drops the first len(scope_name) + 1 characters (the scope name plus
        # the separator). `self.scope_name` is read when the lambda is called,
        # not now.
        name_fn=lambda name: name[len(self.scope_name) + 1:])
  elif custom_getter is not None and not callable(custom_getter):
    raise TypeError("Given custom_getter is not callable.")
  else:
    self._custom_getter = custom_getter

  # Combine the user's getter (possibly None) with the variable-tracking
  # getter. NOTE(review): presumably this is what populates `_all_variables`
  # below; confirm against `_variable_tracking_custom_getter`.
  self._custom_getter = _maybe_wrap_custom_getter(
      _variable_tracking_custom_getter, self._custom_getter)

  self._template = tf.make_template(name,
                                    self._build_wrapper,
                                    create_scope_now_=True,
                                    custom_getter_=self._custom_getter)

  self._original_name = name
  # The last component of the template's variable scope name.
  self._unique_name = self._template.variable_scope.name.split("/")[-1]

  # Update __call__ and the object docstrings to enable better introspection.
  self.__doc__ = self._build.__doc__
  # NOTE(review): `__func__` is the function object shared by the defining
  # class, so this assignment changes the `__call__` docstring class-wide,
  # not just for this instance.
  self.__call__.__func__.__doc__ = self._build.__doc__

  # Keep track of which graph this module has been connected to. Sonnet
  # modules cannot be connected to multiple graphs, as transparent variable
  # sharing is impossible in that case.
  self._graph = None

  # Container for all variables created in this module and its sub-modules.
  self._all_variables = set()