Example #1
    def testCollections(self):
        def fn(x):
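            # With trainable=False and an explicit collections list, v is
            # recorded only in the custom collection 'a'.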
            v = variables.VariableV1(3,
                                     name='v',
                                     trainable=False,
                                     collections=['a'])
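            # get_variable defaults to trainable=True, so v2 also lands in
            # TRAINABLE_VARIABLES in addition to 'a' and 'b'.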
            v2 = variable_scope.get_variable('v',
                                             initializer=init_ops.Constant(4),
                                             shape=[],
                                             dtype=dtypes.int32,
                                             collections=['a', 'b'])
            return v + v2 + x

        def assert_collections(graph):
            self.assertLen(
                graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 1)
            self.assertLen(graph.get_collection('a'), 2)
            self.assertLen(graph.get_collection('b'), 1)

        g = wrap_function.WrappedGraph()
        g.wrap_function(fn, [tensor_spec.TensorSpec([], dtypes.int32)])
        assert_collections(g.graph)

        def assert_fn():
            assert_collections(ops.get_default_graph())
            return 1  # Return is required

        # Assert that collections are accessible within a wrapped function.
        g.wrap_function(assert_fn, [])
Example #2
  def testShareVariablesSameGraph(self):

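    # add_v1, subtract_v1, and increment_variable_v1 all use the 'reuse' scope
    # with AUTO_REUSE, so they resolve to a single shared variable once wrapped
    # into the same WrappedGraph below.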
    def add_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(3), shape=[], dtype=dtypes.int32)
      return v + x

    def subtract_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v - x

    def different_variable_fn_v1(x):
      with variable_scope.variable_scope(
          'no_reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(5), shape=[], dtype=dtypes.int32)
      return v * x

    def increment_variable_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(6), shape=[], dtype=dtypes.int32)
      return v.assign_add(x)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    add = g.wrap_function(add_v1, signature)
    subtract = g.wrap_function(subtract_v1, signature)
    different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
    increment_variable = g.wrap_function(increment_variable_v1, signature)

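    # add_v1 was wrapped first, so the shared variable is initialized to 3:
    # add(7) == 3 + 7 == 10. The 'no_reuse' variable starts at 5: 5 * 7 == 35.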
    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # The shared variable has a starting value of 3 because add_v1 was wrapped
    # first.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that the update from increment_variable is visible to the other
    # functions sharing the variable.
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertAllEqual({'reuse/v', 'no_reuse/v'}, set(g.variables.keys()))
Example #3
    def testReturnOp(self):
        def update_var_v1(x):
            v = variables.Variable(3, name='v')
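            # .op extracts the tf.Operation from the assign; wrap_function
            # supports returning an Operation instead of a Tensor.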
            update_op = state_ops.assign(v, x).op
            return update_op

        g = wrap_function.WrappedGraph()
        signature = [tensor_spec.TensorSpec([], dtypes.int32)]
        update_var = g.wrap_function(update_var_v1, signature)

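        # Variables created inside a wrapped function are exposed through
        # g.variables, keyed by name.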
        self.assertEqual(g.variables['v'].numpy(), 3)
        update_var(constant_op.constant(12))
        self.assertEqual(g.variables['v'].numpy(), 12)
Example #4
    def testShareVariablesDifferentGraphs(self):
        def add_v1(x):
            v = variables.Variable(3, name='v')
            return v + x

        def subtract_v1(x):
            v = variables.Variable(4, name='v')
            return v - x

        def different_variable_fn_v1(x):
            with ops.name_scope('different_scope'):
                v = variables.Variable(5, name='v')
            return v * x

        def increment_variable_v1(x):
            v = variables.Variable(6, name='v')
            return v.assign_add(x)

        signature = [tensor_spec.TensorSpec([], dtypes.int32)]
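        # share_variables=True makes the holder reuse variables by name across
        # the otherwise independent WrappedGraphs created below.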
        vh = wrap_function.VariableHolder(share_variables=True)
        new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)

        add = new_graph().wrap_function(add_v1, signature)
        subtract = new_graph().wrap_function(subtract_v1, signature)
        different_variable_fn = new_graph().wrap_function(
            different_variable_fn_v1, signature)
        increment_variable = new_graph().wrap_function(increment_variable_v1,
                                                       signature)

        self.assertEqual(10, add(constant_op.constant(7)).numpy())
        self.assertEqual(
            35,
            different_variable_fn(constant_op.constant(7)).numpy())

        # Because the variable in add_v1 was created first, its starting value is 3
        # instead of the values defined in subtract_v1 or increment_variable_v1.
        self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
        self.assertEqual(10,
                         increment_variable(constant_op.constant(7)).numpy())

        # Check that the update from increment_variable is visible to the other
        # functions sharing the variable.
        self.assertEqual(17, add(constant_op.constant(7)).numpy())
        self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

        # Sanity check - result from this function shouldn't change.
        self.assertEqual(
            35,
            different_variable_fn(constant_op.constant(7)).numpy())

        self.assertAllEqual({'v', 'different_scope/v'},
                            set(vh.variables.keys()))
Example #5
  def testAddFunction(self):

    def fn(x):
      v = variables.Variable(3, name='v')
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v + v2 + x

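    # Compute the expected result by running fn in a v1 graph-mode session.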
    with self.cached_session() as sess:
      result = fn(constant_op.constant(5))
      sess.run(variables.global_variables_initializer())
      expected = sess.run(result)

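    # The wrapped version of the same fn should produce the identical value
    # when called eagerly.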
    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    wrapped_fn = g.wrap_function(fn, signature)
    self.assertEqual(expected, wrapped_fn(constant_op.constant(5)).numpy())