def testSaveRestoreGraphCallable(self):
  """Round-trips a graph_callable's variable through a checkpoint.

  Exercises explicit Saver.save()/restore() on `model.variables`, and then
  restoring from the checkpoint at variable-creation time inside a fresh graph
  via `restore_variables_on_create`.
  """

  def run_on_two(fn):
    # Builds a fresh float32 scalar 2 at call time (so it is created in
    # whatever graph/device context is active) and fetches the result.
    return fn(array_ops.constant(2, dtype=dtypes.float32)).numpy()

  with ops.device(self._dev()):

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def model(x):
      var = variable_scope.get_variable(
          'v', initializer=init_ops.zeros_initializer(), shape=())
      return var + x

    # Zero-initialized variable: 2 + 0 = 2.
    self.assertEqual(2, run_on_two(model))

    # Checkpoint the current value (0).
    ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
    _saver.Saver(model.variables).save(ckpt_prefix)

    # Change only the in-memory value to 1: 2 + 1 = 3.
    model.variables[0].assign(1.)
    self.assertEqual(3, run_on_two(model))

    # Restoring the checkpoint brings back 0: 2 + 0 = 2.
    _saver.Saver(model.variables).restore(ckpt_prefix)
    self.assertEqual(2, run_on_two(model))

    # Checkpoint now holds 1, while memory holds 2: 2 + 2 = 4.
    model.variables[0].assign(1.)
    _saver.Saver(model.variables).save(ckpt_prefix)
    model.variables[0].assign(2.)
    self.assertEqual(4, run_on_two(model))

    # In a brand-new graph, the variable is created from the checkpointed
    # value 1 instead of its zero initializer: 1 + 2 = 3.
    with ops.Graph().as_default():
      with _saver.restore_variables_on_create(ckpt_prefix):

        @graph_callable.graph_callable([
            graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)
        ])
        def model2(x):
          var = variable_scope.get_variable(
              'v', initializer=init_ops.zeros_initializer(), shape=())
          return var + x

        self.assertEqual(3, run_on_two(model2))
def testPureFunctionWithConstantOp(self):
  """A variable-free graph_callable computes x + 3 on an int32 scalar.

  Renamed from `testPureFunction`: a second method with that name in the same
  class replaced this definition at class-creation time, so it never ran.
  """

  @graph_callable.graph_callable(
      [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
  def f(x):
    return math_ops.add(x, constant_op.constant(3))

  self.assertAllEqual(5, f(constant_op.constant(2)))
def testPureFunction(self):
  """A variable-free graph_callable computes x + 3 on an int32 scalar."""

  @graph_callable.graph_callable(
      [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
  def f(x):
    # Tensors are built with constant_op.constant, as in the rest of this
    # file, rather than the `tensor.Tensor` constructor.
    return math_ops.add(x, constant_op.constant(3))

  self.assertAllEqual(5, f(constant_op.constant(2)).numpy())
def DISABLED_testRepeatedUseOfSubFunction(self):
  """Calls one Defun from two separate graph_callables.

  The DISABLED_ prefix keeps the unittest loader (which collects `test*`
  methods) from running this. NOTE(review): the reason it was disabled is not
  recorded here.
  """

  @function.Defun(dtypes.int32, dtypes.int32)
  def add(a, b):
    return math_ops.add(a, b)

  @graph_callable.graph_callable(
      [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
  def add_one(x):
    return add(x, 1)

  @graph_callable.graph_callable(
      [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
  def add_two(x):
    return add(x, 2)

  # The same input is fed to both callables that share the Defun.
  shared_input = constant_op.constant(2)
  self.assertAllEqual(3, add_one(shared_input))
  self.assertAllEqual(4, add_two(shared_input))
def testFunctionWithoutReturnValue(self):
  """A graph_callable that returns nothing still applies its variable update."""

  @graph_callable.graph_callable(
      [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
  def my_function(x):
    var = variable_scope.get_variable(
        "v", initializer=init_ops.zeros_initializer(), shape=())
    var.assign(x)

  four = constant_op.constant(4, dtype=dtypes.float32)
  my_function(four)
  # The only observable effect is the assignment to the captured variable.
  stored = my_function.variables[0].read_value()
  self.assertAllEqual(4, stored)
def testEmptyInitializer(self):
  """get_variable works inside a graph_callable with no explicit initializer."""

  @graph_callable.graph_callable(
      # `(1,)` is a one-element tuple; `(1)` would just be the int 1.
      [graph_callable.ShapeAndDtype(shape=(1,), dtype=dtypes.float32)])
  def my_function(x):
    v = variable_scope.get_variable("v", shape=[1])
    # Scaling v by 0 presumably keeps the output independent of whatever the
    # default initializer produced.
    return x + 0 * v

  # assertAllEqual compares element-wise rather than relying on the truth
  # value of a (single-element) comparison array.
  self.assertAllEqual(
      [2.],
      my_function(constant_op.constant([2.], dtype=dtypes.float32)).numpy())
def testUpdatesAreOrdered(self):
  """Successive assigns inside a graph_callable run in program order.

  With x = 2, assigning v = x + 1 and then v = v * x gives (2 + 1) * 2 = 6.
  """

  @graph_callable.graph_callable(
      [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
  def my_function(x):
    var = variable_scope.get_variable(
        "v", initializer=init_ops.zeros_initializer(), shape=())
    var.assign(x + 1)
    var.assign(var * x)
    return var.read_value()

  self.assertAllEqual(6.0, my_function(constant_op.constant(2.0)))
def testMismatchingNumArgs(self):
  """graph_callable rejects a function whose arity differs from its specs."""
  # Raw strings keep the regex escapes (\( \)) literal without relying on
  # Python passing through invalid string escape sequences.
  with self.assertRaisesRegexp(TypeError,
                               r"The number of arguments accepted by the "
                               r"decorated function `my_function` \(2\) must "
                               r"match the number of ShapeAndDtype objects "
                               r"passed to the graph_callable\(\) decorator "
                               r"\(1\)."):
    # Only the decorated definition is inside the context manager, so the
    # error is expected when the decorator is applied.
    @graph_callable.graph_callable([
        graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x, y):  # pylint: disable=unused-variable
      return x + y
def testIncorrectlyShapedInputs(self):
  """Inputs must match the shape declared in the ShapeAndDtype spec."""

  @graph_callable.graph_callable(
      # `(3,)` is a one-element tuple; `(3)` would just be the int 3.
      [graph_callable.ShapeAndDtype(shape=(3,), dtype=dtypes.float32)])
  def my_function(x):
    v = variable_scope.get_variable(
        "v", initializer=init_ops.zeros_initializer(), shape=())
    return v + x

  # A length-2 input does not fit the declared length-3 shape.
  with self.assertRaises(ValueError):
    my_function([1, 2])

  # A correctly shaped input passes through unchanged (v starts at zero).
  # assertAllEqual reports the mismatching values on failure, unlike
  # assertTrue(... .all()).
  self.assertAllEqual(
      [1, 2, 3],
      my_function(
          constant_op.constant([1, 2, 3], dtype=dtypes.float32)).numpy())
def testNestedFunction(self):
  """A graph_callable can invoke a TensorFlow Defun in its body."""

  # Defun: the function abstraction used when building TensorFlow graphs.
  @function.Defun(dtypes.int32, dtypes.int32)
  def add(a, b):
    return math_ops.add(a, b)

  # The graph_callable below calls into that Defun.
  @graph_callable.graph_callable(
      [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
  def add_one(x):
    return add(x, 1)

  result = add_one(constant_op.constant(2))
  self.assertAllEqual(3, result)
def testTensorShape(self):
  """The input's static shape is usable inside the decorated function."""

  @graph_callable.graph_callable(
      # `(1,)` is a one-element tuple; `(1)` would just be the int 1.
      [graph_callable.ShapeAndDtype(shape=(1,), dtype=dtypes.float32)])
  def my_function(x):
    _ = x.get_shape()
    # The variable's shape comes from the input's first dimension.
    v = variable_scope.get_variable(
        "v", initializer=init_ops.zeros_initializer(), shape=[x.shape[0]])
    return v + x

  # Element-wise comparison instead of the truth value of an ndarray.
  self.assertAllEqual([2.], my_function(
      constant_op.constant([2.], dtype=dtypes.float32)).numpy())
def testBasic(self):
  """A variable created in a graph_callable is re-read on every call."""

  @graph_callable.graph_callable(
      [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
  def my_function(x):
    var = variable_scope.get_variable(
        "v", initializer=init_ops.zeros_initializer(), shape=())
    return var + x

  def call_with_two():
    return my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy()

  # Zero-initialized: 0 + 2.
  self.assertEqual(2, call_with_two())

  # After assigning 1, the next call observes it: 1 + 2.
  my_function.variables[0].assign(1.)
  self.assertEqual(3, call_with_two())
def testNestedSequenceInputs(self):
  """graph_callable handles nested list/tuple inputs and outputs."""
  sd = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)

  @graph_callable.graph_callable([[sd, tuple([sd, sd]), sd]])
  def my_op(inputs):
    a, b, c = inputs
    e, f = b
    v = variable_scope.get_variable(
        "my_v", initializer=init_ops.zeros_initializer(), shape=())
    return [a + a + v, tuple([e + e, f + f]), c + c], a + e + f + c + v

  inputs = [constant_op.constant(1.),
            [constant_op.constant(2.), constant_op.constant(3.)],
            constant_op.constant(4.)]
  ret = my_op(inputs)
  # Two top-level outputs: the nested structure and the scalar sum. A length
  # is an int, so compare against 2 rather than 2.0.
  self.assertEqual(2, len(ret))
  # 1 + 2 + 3 + 4 + v, with v zero-initialized.
  self.assertEqual(10., ret[1].numpy())

  my_op.variables[0].assign(1.)
  ret = my_op(inputs)
  # Same sum with v now 1.
  self.assertEqual(11., ret[1].numpy())