def testCollections(self):
  """Collections filled while wrapping are kept on the WrappedGraph.

  `fn` adds its variables to the custom collections 'a' and 'b'. After
  wrapping, those collections must be present on `g.graph`. They must also be
  visible from inside a second function wrapped into the same graph.
  """

  def fn(x):
    v = variables.VariableV1(
        3, name='v', trainable=False, collections=['a'])
    v2 = variable_scope.get_variable(
        'v',
        initializer=init_ops.Constant(4),
        shape=[],
        dtype=dtypes.int32,
        collections=['a', 'b'])
    return v + v2 + x

  # (collection key, expected number of entries), checked in this order.
  expected_sizes = (
      # `v` is created with trainable=False, so one trainable entry is expected.
      (ops.GraphKeys.TRAINABLE_VARIABLES, 1),
      ('a', 2),  # Both `v` and `v2` list 'a' in `collections`.
      ('b', 1),  # Only `v2` lists 'b'.
  )

  def assert_collections(graph):
    for collection_key, size in expected_sizes:
      self.assertLen(graph.get_collection(collection_key), size)

  g = wrap_function.WrappedGraph()
  scalar_int_signature = [tensor_spec.TensorSpec([], dtypes.int32)]
  g.wrap_function(fn, scalar_int_signature)
  assert_collections(g.graph)

  def assert_fn():
    assert_collections(ops.get_default_graph())
    return 1  # Return is required

  # Assert that collections are accessible within a wrapped function.
  g.wrap_function(assert_fn, [])
def testShareVariablesSameGraph(self):
  """Functions wrapped into one WrappedGraph share AUTO_REUSE variables.

  All four functions are wrapped into the same `WrappedGraph`. Three of them
  request `reuse/v` and share one variable. Its starting value (3) comes from
  `add_v1`, because that function is wrapped first. The fourth function uses
  the 'no_reuse' scope and gets its own `no_reuse/v`.
  """

  def add_v1(x):
    with variable_scope.variable_scope(
        'reuse', reuse=variable_scope.AUTO_REUSE):
      v = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(3), shape=[], dtype=dtypes.int32)
    return v + x

  def subtract_v1(x):
    with variable_scope.variable_scope(
        'reuse', reuse=variable_scope.AUTO_REUSE):
      v = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
    return v - x

  def different_variable_fn_v1(x):
    with variable_scope.variable_scope(
        'no_reuse', reuse=variable_scope.AUTO_REUSE):
      v = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(5), shape=[], dtype=dtypes.int32)
    return v * x

  def increment_variable_v1(x):
    with variable_scope.variable_scope(
        'reuse', reuse=variable_scope.AUTO_REUSE):
      v = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(6), shape=[], dtype=dtypes.int32)
    return v.assign_add(x)

  g = wrap_function.WrappedGraph()
  signature = [tensor_spec.TensorSpec([], dtypes.int32)]
  # Wrap order matters: add_v1 must be wrapped first so that its initializer
  # determines the starting value of the shared 'reuse/v' variable.
  add = g.wrap_function(add_v1, signature)
  subtract = g.wrap_function(subtract_v1, signature)
  different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
  increment_variable = g.wrap_function(increment_variable_v1, signature)

  self.assertEqual(10, add(constant_op.constant(7)).numpy())
  self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

  # The shared variable has a starting value of 3 because add_v1 was wrapped
  # first.
  self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
  self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

  # Check that variable updates
  self.assertEqual(17, add(constant_op.constant(7)).numpy())
  self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

  # Sanity check - result from this function shouldn't change.
  self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

  # Compare as sets with a set-aware assertion; assertAllEqual would coerce
  # each set into a 0-d numpy object array and report an unhelpful diff.
  self.assertSetEqual({'reuse/v', 'no_reuse/v'}, set(g.variables.keys()))
def testReturnOp(self):
  """A wrapped function may return an Operation instead of a Tensor.

  `update_var_v1` returns the `.op` of an assign. Calling the wrapped function
  must run that op and update the variable held by the WrappedGraph.
  """

  def update_var_v1(x):
    v = variables.Variable(3, name='v')
    update_op = state_ops.assign(v, x).op
    return update_op

  g = wrap_function.WrappedGraph()
  signature = [tensor_spec.TensorSpec([], dtypes.int32)]
  update_var = g.wrap_function(update_var_v1, signature)

  # (expected, actual) argument order, consistent with the other tests, so
  # failure messages label the two values correctly.
  self.assertEqual(3, g.variables['v'].numpy())
  update_var(constant_op.constant(12))
  self.assertEqual(12, g.variables['v'].numpy())
def testShareVariablesDifferentGraphs(self):
  """Separate WrappedGraphs with one VariableHolder share same-named variables.

  Each function is wrapped into its own `WrappedGraph`, and all of those graphs
  use a single `VariableHolder(share_variables=True)`. Functions that create a
  variable named 'v' share one variable; its starting value comes from
  `add_v1`, which is wrapped first. The variable created under the
  'different_scope' name scope is kept separate.
  """

  def add_v1(x):
    v = variables.Variable(3, name='v')
    return v + x

  def subtract_v1(x):
    v = variables.Variable(4, name='v')
    return v - x

  def different_variable_fn_v1(x):
    with ops.name_scope('different_scope'):
      v = variables.Variable(5, name='v')
    return v * x

  def increment_variable_v1(x):
    v = variables.Variable(6, name='v')
    return v.assign_add(x)

  signature = [tensor_spec.TensorSpec([], dtypes.int32)]
  vh = wrap_function.VariableHolder(share_variables=True)

  def new_graph():
    """Returns a fresh WrappedGraph backed by the shared VariableHolder `vh`."""
    return wrap_function.WrappedGraph(variable_holder=vh)

  # Wrap order matters: add_v1 is wrapped first, so its initial value (3) is
  # the one used for the shared 'v' variable.
  add = new_graph().wrap_function(add_v1, signature)
  subtract = new_graph().wrap_function(subtract_v1, signature)
  different_variable_fn = new_graph().wrap_function(
      different_variable_fn_v1, signature)
  increment_variable = new_graph().wrap_function(increment_variable_v1,
                                                 signature)

  self.assertEqual(10, add(constant_op.constant(7)).numpy())
  self.assertEqual(
      35, different_variable_fn(constant_op.constant(7)).numpy())

  # Because the variable in add_v1 was created first, its starting value is 3
  # instead of the values defined in subtract_v1 or increment_variable_v1.
  self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
  self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

  # Check that variable updates
  self.assertEqual(17, add(constant_op.constant(7)).numpy())
  self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

  # Sanity check - result from this function shouldn't change.
  self.assertEqual(
      35, different_variable_fn(constant_op.constant(7)).numpy())

  # Compare as sets with a set-aware assertion; assertAllEqual would coerce
  # each set into a 0-d numpy object array and report an unhelpful diff.
  self.assertSetEqual({'v', 'different_scope/v'}, set(vh.variables.keys()))
def testAddFunction(self):
  """A wrapped function returns the same value as calling `fn` directly.

  `fn` is first evaluated inside `cached_session` to get a reference value.
  It is then wrapped into a fresh `WrappedGraph`, called with the same input,
  and its result is compared against the reference.
  """

  def fn(x):
    v = variables.Variable(3, name='v')
    v2 = variable_scope.get_variable(
        'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
    partial_sum = v + v2
    return partial_sum + x

  # Reference value: run `fn` directly after initializing its variables.
  with self.cached_session() as session:
    direct_output = fn(constant_op.constant(5))
    session.run(variables.global_variables_initializer())
    reference_value = session.run(direct_output)

  wrapped_graph = wrap_function.WrappedGraph()
  scalar_int_signature = [tensor_spec.TensorSpec([], dtypes.int32)]
  wrapped_fn = wrapped_graph.wrap_function(fn, scalar_int_signature)
  wrapped_output = wrapped_fn(constant_op.constant(5))
  self.assertEqual(reference_value, wrapped_output.numpy())