예제 #1
0
    def testSaveRestoreGraphCallable(self):
        """Round-trips a graph_callable's variable through a checkpoint.

        Covers four phases, always feeding the scalar float32 input 2:
          1. fresh model (v == 0) -> 2;
          2. save v == 0, assign 1 in memory -> 3, restore -> 2;
          3. save v == 1, assign 2 in memory -> 4;
          4. in a brand-new graph, build an identical graph_callable inside
             `restore_variables_on_create`, so v is loaded as 1 -> 3.
        """

        def feed_two(fn):
            # Invoke `fn` on a float32 constant 2 and return the numpy result.
            return fn(array_ops.constant(2, dtype=dtypes.float32)).numpy()

        with ops.device(self._dev()):

            @graph_callable.graph_callable(
                [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
            def model(x):
                v = variable_scope.get_variable(
                    'v', initializer=init_ops.zeros_initializer(), shape=())
                return v + x

            # Phase 1: zero-initialized variable, so the output is just x.
            self.assertEqual(2, feed_two(model))

            # Phase 2: checkpoint v == 0, diverge in memory, then restore.
            ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
            _saver.Saver(model.variables).save(ckpt_prefix)

            model.variables[0].assign(1.)
            self.assertEqual(3, feed_two(model))

            _saver.Saver(model.variables).restore(ckpt_prefix)
            self.assertEqual(2, feed_two(model))

            # Phase 3: overwrite the checkpoint with v == 1 while the
            # in-memory value becomes 2.
            model.variables[0].assign(1.)
            _saver.Saver(model.variables).save(ckpt_prefix)
            model.variables[0].assign(2.)
            self.assertEqual(4, feed_two(model))

            # Phase 4: a fresh graph; the new variable 'v' should be
            # populated from the checkpoint (1) at creation time.
            with ops.Graph().as_default():
                with _saver.restore_variables_on_create(ckpt_prefix):

                    @graph_callable.graph_callable([
                        graph_callable.ShapeAndDtype(shape=(),
                                                     dtype=dtypes.float32)
                    ])
                    def model2(x):
                        v = variable_scope.get_variable(
                            'v',
                            initializer=init_ops.zeros_initializer(),
                            shape=())
                        return v + x

                    self.assertEqual(3, feed_two(model2))
예제 #2
0
    def testPureFunction(self):
        """A graph_callable with no variables acts like an ordinary function."""
        scalar_int = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)

        @graph_callable.graph_callable([scalar_int])
        def f(x):
            three = constant_op.constant(3)
            return math_ops.add(x, three)

        result = f(constant_op.constant(2))
        self.assertAllEqual(5, result)
예제 #3
0
    def testPureFunction(self):
        """A variable-free graph_callable fed with ``tensor.Tensor`` values."""
        input_spec = [
            graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)]

        @graph_callable.graph_callable(input_spec)
        def f(x):
            return math_ops.add(x, tensor.Tensor(3))

        output = f(tensor.Tensor(2))
        self.assertAllEqual(5, output.numpy())
예제 #4
0
    def DISABLED_testRepeatedUseOfSubFunction(self):
        """Two graph_callables invoking the same Defun (currently disabled).

        `add_one` and `add_two` both call the shared `add` Defun with
        different constants; each is expected to produce its own sum.
        """
        @function.Defun(dtypes.int32, dtypes.int32)
        def add(a, b):
            return math_ops.add(a, b)

        scalar_int = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)

        @graph_callable.graph_callable([scalar_int])
        def add_one(x):
            return add(x, 1)

        @graph_callable.graph_callable([scalar_int])
        def add_two(x):
            return add(x, 2)

        two = constant_op.constant(2)
        for expected, fn in ((3, add_one), (4, add_two)):
            self.assertAllEqual(expected, fn(two))
예제 #5
0
    def testFunctionWithoutReturnValue(self):
        """A graph_callable may return nothing yet still update its variable."""

        @graph_callable.graph_callable(
            [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
        def my_function(x):
            v = variable_scope.get_variable(
                "v", initializer=init_ops.zeros_initializer(), shape=())
            # No return: the only observable effect is the assignment to v,
            # checked below through `my_function.variables`.
            v.assign(x)

        four = constant_op.constant(4, dtype=dtypes.float32)
        my_function(four)
        stored = my_function.variables[0].read_value()
        self.assertAllEqual(4, stored)
예제 #6
0
    def testEmptyInitializer(self):
        """get_variable without an explicit initializer works in a graph_callable.

        The output is `x + 0 * v`, so it equals the input regardless of the
        value `v` is initialized to.
        """
        # `(1,)` is a one-element shape tuple; the original `(1)` was just the
        # int 1, which only worked because a bare int is accepted as a
        # single dimension.
        @graph_callable.graph_callable(
            [graph_callable.ShapeAndDtype(shape=(1,), dtype=dtypes.float32)])
        def my_function(x):
            v = variable_scope.get_variable("v", shape=[1])
            return x + 0 * v

        # assertAllEqual compares array contents element-wise; assertEqual on
        # an ndarray relied on the truthiness of a one-element boolean array.
        self.assertAllEqual(
            [2.],
            my_function(constant_op.constant([2.],
                                             dtype=dtypes.float32)).numpy())
예제 #7
0
    def testUpdatesAreOrdered(self):
        """Assignments inside a graph_callable take effect in program order.

        With x = 2: first v <- x + 1 = 3, then v <- v * x = 6.
        """
        @graph_callable.graph_callable(
            [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
        def my_function(x):
            v = variable_scope.get_variable(
                "v", initializer=init_ops.zeros_initializer(), shape=())
            v.assign(x + 1)  # v == x + 1
            v.assign(v * x)  # v == (x + 1) * x
            return v.read_value()

        result = my_function(constant_op.constant(2.0))
        self.assertAllEqual(result, 6.0)
예제 #8
0
 def testMismatchingNumArgs(self):
   """graph_callable rejects a function whose arity differs from its specs.

   `my_function` takes two arguments but only one ShapeAndDtype is given,
   so decorating it should raise a TypeError with the message below.
   """
   # Raw strings keep the regex escapes (`\(`, `\)`, `\.`) out of Python's
   # string-escape processing, so no invalid-escape warning or pylint
   # suppression is needed. The final period is escaped so it matches a
   # literal '.', not any character.
   expected_message = (r"The number of arguments accepted by the "
                       r"decorated function `my_function` \(2\) must "
                       r"match the number of ShapeAndDtype objects "
                       r"passed to the graph_callable\(\) decorator "
                       r"\(1\)\.")
   with self.assertRaisesRegexp(TypeError, expected_message):
     @graph_callable.graph_callable([
         graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
     def my_function(x, y):  # pylint: disable=unused-variable
       return x + y
예제 #9
0
  def testIncorrectlyShapedInputs(self):
    """Inputs must match the declared ShapeAndDtype shape.

    A length-2 input is rejected with ValueError. A correctly shaped length-3
    float32 input is added to the zero scalar variable and comes back
    unchanged.
    """
    # `(3,)` is a one-element shape tuple; `(3)` would just be the int 3.
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(3,), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      return v + x

    with self.assertRaises(ValueError):
      my_function([1, 2])

    # assertAllEqual checks shape and values and reports a useful diff,
    # unlike assertTrue((list == ndarray).all()).
    self.assertAllEqual(
        [1, 2, 3],
        my_function(
            constant_op.constant([1, 2, 3], dtype=dtypes.float32)).numpy())
예제 #10
0
    def testNestedFunction(self):
        """A graph_callable body may invoke a TensorFlow Defun."""

        # Graph-construction TensorFlow function over two int32 inputs.
        @function.Defun(dtypes.int32, dtypes.int32)
        def add(a, b):
            return math_ops.add(a, b)

        # The graph_callable simply delegates to the Defun above.
        @graph_callable.graph_callable(
            [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
        def add_one(x):
            return add(x, 1)

        two = constant_op.constant(2)
        self.assertAllEqual(3, add_one(two))
예제 #11
0
  def testTensorShape(self):
    """The traced input's static shape can size a variable.

    The variable gets shape `[x.shape[0]]` (here [1]) and is zero-initialized,
    so `v + x` equals the input.
    """

    # `(1,)` is a one-element shape tuple; `(1)` would just be the int 1.
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(1,), dtype=dtypes.float32)])
    def my_function(x):
      # Result intentionally discarded; presumably this checks that
      # get_shape() can be called on the traced input — confirm intent.
      _ = x.get_shape()
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=[x.shape[0]])
      return v + x

    # assertAllEqual compares arrays element-wise instead of relying on the
    # truthiness of a one-element boolean ndarray.
    self.assertAllEqual([2.],
                        my_function(
                            constant_op.constant([2.],
                                                 dtype=dtypes.float32)).numpy())
예제 #12
0
  def testBasic(self):
    """A graph_callable's variable is created and can be reassigned.

    The variable starts at zero (2 + 0 == 2). After assigning 1.0 to it, the
    same call yields 2 + 1 == 3.
    """

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      return v + x

    def call_with_two():
      # Fresh float32 constant 2 for each invocation.
      two = constant_op.constant(2, dtype=dtypes.float32)
      return my_function(two).numpy()

    self.assertEqual(2, call_with_two())

    my_function.variables[0].assign(1.)
    self.assertEqual(3, call_with_two())
예제 #13
0
  def testNestedSequenceInputs(self):
    """graph_callable accepts and returns nested list/tuple structures.

    The single input spec is `[scalar, (scalar, scalar), scalar]`. The
    function returns a pair: a nested structure of doubled values, and the
    scalar sum `a + e + f + c + v`.
    """
    sd = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)
    @graph_callable.graph_callable([[sd, (sd, sd), sd]])
    def my_op(inputs):
      a, b, c = inputs
      e, f = b
      v = variable_scope.get_variable(
          "my_v", initializer=init_ops.zeros_initializer(), shape=())
      return [a + a + v, (e + e, f + f), c + c], a + e + f + c + v

    inputs = [constant_op.constant(1.),
              [constant_op.constant(2.), constant_op.constant(3.)],
              constant_op.constant(4.)]
    ret = my_op(inputs)
    # len() is an int; compare against the int 2, not the float 2.
    self.assertEqual(len(ret), 2)
    # v starts at 0: 1 + 2 + 3 + 4 + 0 == 10.
    self.assertEqual(ret[1].numpy(), 10.)

    my_op.variables[0].assign(1.)
    ret = my_op(inputs)
    # With v == 1 the sum becomes 11.
    self.assertEqual(ret[1].numpy(), 11.)