def test_variable_replace_getter(self):
        """Exercise VariableReplaceGetter with variables and value dicts.

        The test builds a one-output `snt.Linear` whose variables are created
        through the getter, then checks three things:

        1. The module runs with its real variables (`use_variables`).
        2. Running with a value dict in which every entry is incremented by 1
           shifts the output by 2.
        3. After assigning the same increment to the real variables, running
           with `use_variables` matches the value-dict output exactly.
        """
        with self.test_session() as sess:
            context = variable_replace.VariableReplaceGetter()
            mod = snt.Linear(1, custom_getter=context)

            inp_data = tf.ones([10, 1])
            with context.use_variables():
                y1 = mod(inp_data)
                # `tf.initialize_all_variables` is deprecated; this is its
                # documented replacement and builds the same init op.
                sess.run(tf.global_variables_initializer())
                np_y1 = sess.run(y1)

            # The dict values are the variables themselves (they are assigned
            # to further down).
            values = context.get_variable_dict()

            # Substitute tensors: each variable shifted by one.
            new_values = {k: v + 1 for k, v in values.items()}

            with context.use_value_dict(new_values):
                np_y2 = mod(inp_data).eval()

            # With an all-ones input of width 1, bumping every entry by 1 is
            # expected to raise the output by 2 (presumably one from the
            # weight and one from the bias -- confirm against snt.Linear).
            self.assertNear((np_y2 - np_y1)[0], 2, 1e-8)

            # Apply the same +1 shift to the real variables in place.
            for v in values.values():
                v.assign(v + 1).eval()

            with context.use_variables():
                np_y3 = mod(inp_data).eval()
                # The real variables now equal the substituted values, so the
                # two outputs must agree.
                self.assertNear((np_y3 - np_y2)[0], 0, 1e-8)
# Example #2
# 0
    def __init__(self, module_fn, name="BaseModel"):
        """Build a _BaseModel around the module produced by `module_fn`.

        Args:
          module_fn: Function that returns the sonnet module to be wrapped.
            It is called with no arguments.
          name: Name of this sonnet module.
        """
        # A single getter instance is stored on `self.context` and also passed
        # to the parent constructor as its custom getter.
        replace_getter = variable_replace.VariableReplaceGetter(verbose=False)
        self.context = replace_getter
        super(_BaseModel, self).__init__(
            name=name, custom_getter=replace_getter)

        # The wrapped module is constructed while this module's variable scope
        # is entered.
        with self._enter_variable_scope():
            wrapped = module_fn()
            self.mod = wrapped