Example #1
0
    def __init__(self, build, name=None):
        """Wraps a network-assembling callable in a module.

        A `Module` turns a plain function that builds part of a network into
        a module object.

        The snippet below defines a simple one-hidden-layer MLP in a function
        named `make_model` and wraps that function in a `Module`:

        ```python
        def make_model(inputs):
          lin1 = snt.Linear(name="lin1", output_size=10)(inputs)
          relu1 = tf.nn.relu(lin1, name="relu1")
          lin2 = snt.Linear(name="lin2", output_size=20)(relu1)
          return lin2

        model = snt.Module(name='simple_mlp', build=make_model)
        outputs = model(inputs)
        ```

        Configuration parameters can be baked into the function at
        construction time with `functools.partial`:

        ```python
        from functools import partial

        def make_model(inputs, output_sizes):
          lin1 = snt.Linear(name="lin1", output_size=output_sizes[0])(inputs)
          relu1 = tf.nn.relu(lin1, name="relu1")
          lin2 = snt.Linear(name="lin2", output_size=output_sizes[1])(relu1)
          return lin2

        model = snt.Module(name='simple_mlp',
                           build=partial(make_model, output_sizes=[10, 20]))
        outputs = model(inputs)
        ```

        Args:
          build: Callable to be invoked when connecting the module to the
              graph. It is invoked when the module is called, and its role is
              to specify how to add elements to the Graph and how to compute
              output Tensors from input Tensors. Its signature can include:
                *args - Input Tensors.
                **kwargs - Additional Python parameters controlling
                    connection.
          name: Module name. If `None` (the default), the name is derived from
              the `build` callable, converted to `snake_case`. If `build` has
              no name, the name will be 'module'.

        Raises:
          TypeError: If build is not callable.
        """
        if not callable(build):
            raise TypeError("Input 'build' must be callable.")
        # Only derive a name from the callable when the caller gave none.
        module_name = util.name_for_callable(build) if name is None else name
        super(Module, self).__init__(name=module_name)
        self._build_function = build
Example #2
0
 def assertName(self, func, expected):
     """Asserts that `util.name_for_callable(func)` equals `expected`."""
     self.assertEqual(util.name_for_callable(func), expected)
Example #3
0
  def __init__(self, build, custom_getter=None, name=None):
    """Wraps a network-assembling callable in a module.

    A `Module` turns a plain function that builds part of a network into a
    module object.

    The snippet below defines a simple one-hidden-layer MLP in a function
    named `make_model` and wraps that function in a `Module`:

    ```python
    def make_model(inputs):
      lin1 = snt.Linear(name="lin1", output_size=10)(inputs)
      relu1 = tf.nn.relu(lin1, name="relu1")
      lin2 = snt.Linear(name="lin2", output_size=20)(relu1)
      return lin2

    model = snt.Module(name='simple_mlp', build=make_model)
    outputs = model(inputs)
    ```

    Configuration parameters can be baked into the function at construction
    time with `functools.partial`:

    ```python
    from functools import partial

    def make_model(inputs, output_sizes):
      lin1 = snt.Linear(name="lin1", output_size=output_sizes[0])(inputs)
      relu1 = tf.nn.relu(lin1, name="relu1")
      lin2 = snt.Linear(name="lin2", output_size=output_sizes[1])(relu1)
      return lin2

    model = snt.Module(name='simple_mlp',
                       build=partial(make_model, output_sizes=[10, 20]))
    outputs = model(inputs)
    ```

    Args:
      build: Callable to be invoked when connecting the module to the graph.
          It is invoked when the module is called, and its role is to specify
          how to add elements to the Graph and how to compute output Tensors
          from input Tensors. Its signature can include:
            *args - Input Tensors.
            **kwargs - Additional Python parameters controlling connection.
      custom_getter: Callable or dictionary of callables to use as custom
          getters inside the module. If a dictionary, the keys correspond to
          regexes to match variable names. See the `tf.get_variable`
          documentation for information about the custom_getter API.
      name: Module name. If `None` (the default), the name is derived from the
          `build` callable, converted to `snake_case`. If `build` has no name,
          the name will be 'module'.

    Raises:
      TypeError: If build is not callable.
      TypeError: If a given `custom_getter` is not callable.
    """
    if not callable(build):
      raise TypeError("Input 'build' must be callable.")
    # Only derive a name from the callable when the caller gave none.
    module_name = util.name_for_callable(build) if name is None else name
    super(Module, self).__init__(custom_getter=custom_getter, name=module_name)
    self._build_function = build
Example #4
0
 def assertName(self, func, expected):
   """Checks that the name computed for `func` matches `expected`."""
   actual = util.name_for_callable(func)
   self.assertEqual(actual, expected)