def test_variable_replace_getter(self):
    """Check that VariableReplaceGetter can swap a module's variables.

    A ``snt.Linear(1)`` built with the getter as its custom getter is applied
    to an all-ones input in three ways:

    1. through its real variables (``use_variables``);
    2. through a value dict mapping every variable to ``v + 1``
       (``use_value_dict``); the test expects each output to grow by 2;
    3. through its real variables again after assigning ``v + 1`` to each
       of them; this should reproduce the result of step 2.
    """
    # tf.ones defaults to float32. One float32 ulp near 2.0 is ~2.4e-7, so a
    # 1e-8 tolerance would demand bit-exact arithmetic; use a tolerance that
    # float32 rounding of ``v + 1`` can actually meet.
    tol = 1e-5
    with self.test_session() as sess:
        context = variable_replace.VariableReplaceGetter()
        mod = snt.Linear(1, custom_getter=context)
        inp_data = tf.ones([10, 1])

        with context.use_variables():
            y1 = mod(inp_data)
        # global_variables_initializer is the non-deprecated replacement for
        # initialize_all_variables.
        sess.run(tf.global_variables_initializer())
        np_y1 = sess.run(y1)

        values = context.get_variable_dict()
        # Each ``v + 1`` is a graph tensor that reads the variable when run.
        new_values = {k: v + 1 for k, v in values.items()}
        with context.use_value_dict(new_values):
            np_y2 = mod(inp_data).eval()
        # Index with [0, 0] to compare a true scalar; [0] would leave a
        # 1-element array for assertNear to convert.
        self.assertNear((np_y2 - np_y1)[0, 0], 2, tol)

        for v in values.values():
            v.assign(v + 1).eval()
        with context.use_variables():
            np_y3 = mod(inp_data).eval()
        self.assertNear((np_y3 - np_y2)[0, 0], 0, tol)
def __init__(self, module_fn, name="BaseModel"):
    """Construct a _BaseModel wrapping the module built by ``module_fn``.

    Args:
      module_fn: Zero-argument callable that builds and returns the sonnet
        module to wrap.
      name: Name given to this sonnet module.
    """
    # A non-verbose VariableReplaceGetter is kept on ``self.context`` and
    # also passed to the base constructor as this module's custom getter.
    replace_getter = variable_replace.VariableReplaceGetter(verbose=False)
    self.context = replace_getter
    super(_BaseModel, self).__init__(name=name, custom_getter=replace_getter)

    # Build the wrapped module inside this module's variable scope;
    # presumably this is what routes its variables through the custom getter
    # above (confirm against the sonnet AbstractModule implementation).
    with self._enter_variable_scope():
        self.mod = module_fn()