Example #1
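These examples are excerpted from TensorFlow's tests for custom_gradient.get_dependent_variables; each one builds a small graph and checks which trainable variables the utility reports. They assume TensorFlow's internal-module imports, roughly as follows (a sketch based on the tensorflow.python source layout):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables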
  def testVariablesOutsideAndCustomGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      @custom_gradient.custom_gradient
      def _MyOnesLike(x):
        """Dummy version of ones_like which defines a gradient."""

        output = array_ops.ones_like(x)

        def _Grad(dy):
          return array_ops.identity(dy)

        return output, _Grad

      def _Func(x):
        # Non-differentiable operation with a custom gradient.
        # The variable should be found.
        y = _MyOnesLike(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])
Example #2
  def testNoVariables(self):
    with ops.Graph().as_default():
      func = lambda x: array_ops.identity(x) + 5.0
      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])

      # There are no variables.
      self.assertEqual(dependent_vars, [])
Example #3
  def testVariablesOutsideButDSeparated(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is d-separated by the inputs. It should not be found.
      input_t = array_ops.identity(var) * 5.0

      func = lambda x: array_ops.identity(x) + 5.0
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])
Example #4
  def testVariablesOutside(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + var

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])
Example #5
  def testVariableSamePrefix(self):
    with ops.Graph().as_default():
      var_name = "my_variable"
      v_z = variable_scope.get_variable(var_name, shape=())
      v_o = variable_scope.get_variable(var_name + "_ones", shape=())

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + v_z + v_o

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(set(dependent_vars), set([v_o, v_z]))
Example #6
  def testVariablesOutsideAndNonDifferentiable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      def _Func(x):
        # Non-differentiable dependency on var.
        # The variable should not be found.
        y = array_ops.ones_like(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])
Example #7
  def testVariablesOutsideAndNonTrainable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))

      # Both variables are used in the function but only the trainable one
      # should be found.
      var_trainable = variables.Variable(init, shape=(5,))
      var_nontrainable = variables.Variable(init, shape=(5,), trainable=False)

      def _Func(x):
        del x
        return var_trainable + var_nontrainable

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var_trainable])
Example #8
  def testNesting(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      def _Func(inputs):
        x = inputs["x"]
        result = array_ops.identity(x) + 5.0 + var
        return {
            "y": result
        }

      input_t = constant_op.constant(2.0)
      func_inputs = {
          "x": input_t
      }
      result_t = _Func(func_inputs)

      # Ensure we can deal with dictionary input and output.
      dependent_vars = custom_gradient.get_dependent_variables(
          func_inputs, result_t)
      self.assertEqual(dependent_vars, [var])
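
For context beyond the assertions above, here is a minimal sketch (not part of the test file) of how the variables returned by get_dependent_variables can feed a graph-mode gradient computation. It reuses the imports assumed earlier, plus gradients_impl.gradients, TensorFlow's graph-mode gradient entry point:

from tensorflow.python.ops import gradients_impl

with ops.Graph().as_default():
  var = variables.Variable(constant_op.constant(3.0))
  input_t = constant_op.constant(2.0)
  result_t = array_ops.identity(input_t) * var

  # The variable is closed over, so it should be reported as a dependency.
  dependent_vars = custom_gradient.get_dependent_variables(
      [input_t], [result_t])

  # Differentiate with respect to both the explicit input and the
  # discovered variables, roughly as the custom_gradient decorator
  # does internally for variables in graph mode.
  grads = gradients_impl.gradients(result_t, [input_t] + dependent_vars)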