def __init__(self, build, name=None):
  """Constructs a module with a given build function.

  The Module class can be used to wrap a function assembling a network into a
  module.

  For example, the following code implements a simple one-hidden-layer MLP
  model by defining a function called `make_model` and using a Module instance
  to wrap it.

  ```python
  def make_model(inputs):
    lin1 = snt.Linear(name="lin1", output_size=10)(inputs)
    relu1 = tf.nn.relu(lin1, name="relu1")
    lin2 = snt.Linear(name="lin2", output_size=20)(relu1)
    return lin2

  model = snt.Module(name='simple_mlp', build=make_model)
  outputs = model(inputs)
  ```

  The `partial` function from `functools` can be used to bake configuration
  parameters into the function at construction time, as shown in the
  following example.

  ```python
  from functools import partial

  def make_model(inputs, output_sizes):
    lin1 = snt.Linear(name="lin1", output_size=output_sizes[0])(inputs)
    relu1 = tf.nn.relu(lin1, name="relu1")
    lin2 = snt.Linear(name="lin2", output_size=output_sizes[1])(relu1)
    return lin2

  model = snt.Module(name='simple_mlp',
                     build=partial(make_model, output_sizes=[10, 20]))
  outputs = model(inputs)
  ```

  Args:
    build: Callable to be invoked when connecting the module to the graph.
        The `build` function is invoked when the module is called, and its
        role is to specify how to add elements to the Graph, and how to
        compute output Tensors from input Tensors.
        The `build` function signature can include the following parameters:
          *args - Input Tensors.
          **kwargs - Additional Python parameters controlling connection.
    name: Module name. If set to `None` (the default), the name will be set to
        that of the `build` callable converted to `snake_case`. If `build` has
        no name, the name will be 'module'.

  Raises:
    TypeError: If `build` is not callable.
  """
  if not callable(build):
    raise TypeError("Input 'build' must be callable.")
  if name is None:
    name = util.name_for_callable(build)
  super(Module, self).__init__(name=name)
  self._build_function = build

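To illustrate the wrapping pattern from the docstring above without depending on TensorFlow or Sonnet, here is a minimal pure-Python sketch. `SketchModule`, `linear`, `relu`, and the list-based "tensors" are hypothetical stand-ins, not Sonnet's API; the sketch only shows how a stored `build` callable is forwarded the call-time inputs, and how `functools.partial` fixes `output_sizes` at construction time so that only `inputs` is supplied later.

```python
import functools


class SketchModule:
  """Hypothetical stand-in for snt.Module that wraps a build callable."""

  def __init__(self, build, name=None):
    if not callable(build):
      raise TypeError("Input 'build' must be callable.")
    # Name derivation from `build` is omitted in this sketch.
    self.name = name
    self._build_function = build

  def __call__(self, *args, **kwargs):
    # Calling the module forwards all inputs to the stored build callable.
    return self._build_function(*args, **kwargs)


def linear(inputs, output_size, weight=1.0):
  # Toy dense layer: every output unit equals weight * sum(inputs).
  total = weight * sum(inputs)
  return [total] * output_size


def relu(values):
  # Element-wise max(0, v).
  return [max(0.0, v) for v in values]


def make_model(inputs, output_sizes):
  lin1 = linear(inputs, output_sizes[0])
  relu1 = relu(lin1)
  lin2 = linear(relu1, output_sizes[1])
  return lin2


# `partial` fixes output_sizes now; only `inputs` is supplied at call time.
model = SketchModule(name="simple_mlp",
                     build=functools.partial(make_model, output_sizes=[10, 20]))
outputs = model([1.0, -2.0, 3.0])
```

Because configuration lives in the `partial` object rather than in the module, the same `make_model` function can back several differently sized modules.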
def assertName(self, func, expected):
  name = util.name_for_callable(func)
  self.assertEqual(name, expected)

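As a sketch of what tests built on an `assertName` helper might exercise, the following assumes a hypothetical `name_for_callable` that follows the naming rules in the `Module` docstring: the callable's `__name__` converted to `snake_case`, `functools.partial` objects unwrapped to the function they wrap, and `None` for callables without a usable name (such as lambdas), leaving the caller to fall back to 'module'. This is an illustration under those assumptions, not Sonnet's `util` implementation.

```python
import functools
import re
import unittest


def to_snake_case(camel_case):
  # Insert "_" before each capital that follows a lowercase letter or digit,
  # then lowercase everything, e.g. "MakeModel" -> "make_model".
  return re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", camel_case).lower()


def name_for_callable(func):
  """Hypothetical name derivation; returns None when no name is available."""
  if isinstance(func, functools.partial):
    # A partial has no __name__ of its own; use the wrapped function's.
    return name_for_callable(func.func)
  name = getattr(func, "__name__", None)
  if name is None or name == "<lambda>":
    return None
  return to_snake_case(name)


class NameForCallableTest(unittest.TestCase):

  def assertName(self, func, expected):
    name = name_for_callable(func)
    self.assertEqual(name, expected)

  def testCamelCaseFunction(self):
    def MakeModel(inputs):
      return inputs
    self.assertName(MakeModel, "make_model")

  def testPartial(self):
    def make_model(inputs, output_sizes):
      return inputs
    self.assertName(functools.partial(make_model, output_sizes=[10, 20]),
                    "make_model")

  def testLambdaHasNoName(self):
    self.assertName(lambda x: x, None)
```

The test case can be run with a standard runner, for example `python -m unittest`.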
def __init__(self, build, custom_getter=None, name=None):
  """Constructs a module with a given build function.

  The Module class can be used to wrap a function assembling a network into a
  module.

  For example, the following code implements a simple one-hidden-layer MLP
  model by defining a function called `make_model` and using a Module instance
  to wrap it.

  ```python
  def make_model(inputs):
    lin1 = snt.Linear(name="lin1", output_size=10)(inputs)
    relu1 = tf.nn.relu(lin1, name="relu1")
    lin2 = snt.Linear(name="lin2", output_size=20)(relu1)
    return lin2

  model = snt.Module(name='simple_mlp', build=make_model)
  outputs = model(inputs)
  ```

  The `partial` function from `functools` can be used to bake configuration
  parameters into the function at construction time, as shown in the
  following example.

  ```python
  from functools import partial

  def make_model(inputs, output_sizes):
    lin1 = snt.Linear(name="lin1", output_size=output_sizes[0])(inputs)
    relu1 = tf.nn.relu(lin1, name="relu1")
    lin2 = snt.Linear(name="lin2", output_size=output_sizes[1])(relu1)
    return lin2

  model = snt.Module(name='simple_mlp',
                     build=partial(make_model, output_sizes=[10, 20]))
  outputs = model(inputs)
  ```

  Args:
    build: Callable to be invoked when connecting the module to the graph.
        The `build` function is invoked when the module is called, and its
        role is to specify how to add elements to the Graph, and how to
        compute output Tensors from input Tensors.
        The `build` function signature can include the following parameters:
          *args - Input Tensors.
          **kwargs - Additional Python parameters controlling connection.
    custom_getter: Callable or dictionary of callables to use as custom
        getters inside the module. If a dictionary, the keys correspond to
        regexes to match variable names. See the `tf.get_variable`
        documentation for information about the custom_getter API.
    name: Module name. If set to `None` (the default), the name will be set to
        that of the `build` callable converted to `snake_case`. If `build` has
        no name, the name will be 'module'.

  Raises:
    TypeError: If `build` is not callable.
    TypeError: If a given `custom_getter` is not callable.
  """
  if not callable(build):
    raise TypeError("Input 'build' must be callable.")
  if name is None:
    name = util.name_for_callable(build)
  super(Module, self).__init__(custom_getter=custom_getter, name=name)
  self._build_function = build
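When `custom_getter` is a dictionary, its keys are regexes matched against variable names. The following pure-Python sketch shows one way such a dictionary could be collapsed into a single getter using the convention of `tf.get_variable` custom getters, where the custom getter receives the true getter as its first argument. `route_custom_getters`, `base_getter`, `ones_getter`, and the dict-based "variables" are hypothetical; matching against the full name passed in, and raising `KeyError` when several regexes match, are design choices of this sketch rather than documented Sonnet behavior.

```python
import re


def route_custom_getters(custom_getter_map):
  """Builds one custom getter from a {regex: getter} dictionary (sketch)."""
  for fn in custom_getter_map.values():
    if not callable(fn):
      raise TypeError("Given custom_getter is not callable.")

  def _router(getter, name, *args, **kwargs):
    # Collect every custom getter whose regex matches the variable name.
    matches = [fn for pattern, fn in custom_getter_map.items()
               if re.match(pattern, name)]
    if not matches:
      # No regex matched: defer to the true getter unchanged.
      return getter(name, *args, **kwargs)
    if len(matches) > 1:
      # Ambiguous routing; this sketch refuses to guess.
      raise KeyError("More than one custom getter matched " + name)
    return matches[0](getter, name, *args, **kwargs)

  return _router


def base_getter(name, shape):
  # Hypothetical "true getter": a variable is a dict holding zeros.
  return {"name": name, "value": [0.0] * shape}


def ones_getter(getter, name, *args, **kwargs):
  # Custom getter: create via the true getter, then overwrite with ones.
  variable = getter(name, *args, **kwargs)
  variable["value"] = [1.0] * len(variable["value"])
  return variable


router = route_custom_getters({r".*/w$": ones_getter})
weights = router(base_getter, "lin1/w", 3)  # matches r".*/w$"
biases = router(base_getter, "lin1/b", 3)   # no match, uses base_getter
```

Validating callability up front mirrors the docstring's promise of a `TypeError` for a non-callable `custom_getter`, so a bad dictionary fails at construction rather than at variable-creation time.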