Exemplo n.º 1
0
    def testSecondDerivative(self):
        """Check first- and second-order gradients taken through cond_v2.

        The test is currently skipped up front (tracked as b/109758172), so
        nothing below the skipTest call runs.
        """
        self.skipTest("b/109758172")
        pred = array_ops.placeholder(dtypes.bool, name="pred")
        x = constant_op.constant(3.0, name="x")

        def true_fn():
            # Branch for pred == True: x cubed.
            return math_ops.pow(x, 3)

        def false_fn():
            # Branch for pred == False: x unchanged.
            return x

        output = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
        first_order = gradients_impl.gradients(output, [x])
        second_order = gradients_impl.gradients(first_order, [x])

        on_true = {pred: True}
        on_false = {pred: False}
        with self.test_session() as sess:
            # First derivative: d[x^3]/dx = 3x^2 and d[x]/dx = 1.
            self.assertEqual(
                sess.run(first_order, feed_dict=on_true), [27.0])
            self.assertEqual(
                sess.run(first_order, feed_dict=on_false), [1.0])

            # Second derivative: d2[x^3]/dx2 = 6x and d2[x]/dx2 = 0.
            self.assertEqual(
                sess.run(second_order, feed_dict=on_true), [18.0])
            self.assertEqual(
                sess.run(second_order, feed_dict=on_false), [0.0])
Exemplo n.º 2
0
  def testSecondDerivative(self):
    """Check first- and second-order gradients taken through cond_v2.

    The true branch computes x^3 and the false branch returns x, with
    x = 3.0, so the expected derivatives can be written down directly.
    """
    pred = array_ops.placeholder(dtypes.bool, name="pred")
    x = constant_op.constant(3.0, name="x")

    def true_fn():
      return math_ops.pow(x, 3)

    def false_fn():
      return x

    result = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
    grad = gradients_impl.gradients(result, [x])
    grad_grad = gradients_impl.gradients(grad, [x])

    # (fetch, value fed for pred, expected result), in evaluation order.
    checks = (
        (grad, True, [27.0]),  # d[x^3]/dx = 3x^2
        (grad, False, [1.0]),  # d[x]/dx = 1
        (grad_grad, True, [18.0]),  # d2[x^3]/dx2 = 6x
        (grad_grad, False, [0.0]),  # d2[x]/dx2 = 0
    )
    with self.test_session() as sess:
      for fetch, pred_value, want in checks:
        self.assertEqual(sess.run(fetch, {pred: pred_value}), want)
Exemplo n.º 3
0
  def testDeviceBeforeCond(self):
    """Check that a device scope entered before cond_v2 reaches its branches.

    Each branch function asserts that the constant it creates carries the
    device string of the `ops.device` scope surrounding the cond_v2 call.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):
        def fn():
          c = constant_op.constant(3.0)
          self.assertEqual("/device:CPU:0", c.op.device)
          return c

        with ops.device("/device:CPU:0"):
          # assertEqual rather than the deprecated assertEquals alias, which
          # no longer exists in unittest as of Python 3.12.
          self.assertEqual(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)

        def fn2():
          c = constant_op.constant(3.0)
          self.assertEqual("/device:GPU:0", c.op.device)
          return c

        with ops.device("/device:GPU:0"):
          self.assertEqual(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
Exemplo n.º 4
0
  def _createCond(self, name):
    """Build a cond_v2 called `name` and return the op behind its output.

    The predicate is a bool placeholder named "pred". The true branch
    returns the constant x (1.0); the false branch returns x + 1.

    Args:
      name: Name passed through to cond_v2.

    Returns:
      The `.op` of the first tensor returned by cond_v2.
    """
    pred = array_ops.placeholder(dtypes.bool, name="pred")
    x = constant_op.constant(1.0, name="x")

    def true_fn():
      return x

    def false_fn():
      return x + 1

    outputs = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
    first_output = outputs[0]
    return first_output.op
Exemplo n.º 5
0
  def testNoInputs(self):
    """Run a cond_v2 whose branches only create constants of their own."""
    pred = array_ops.placeholder(dtypes.bool, name="pred")

    def true_fn():
      return constant_op.constant(1.0)

    def false_fn():
      return constant_op.constant(2.0)

    out = cond_v2.cond_v2(pred, true_fn, false_fn)

    # Value fed for pred paired with the result expected from that branch.
    cases = ((True, [1.0]), (False, [2.0]))
    with self.test_session() as sess:
      for pred_value, want in cases:
        self.assertEqual(sess.run(out, {pred: pred_value}), want)
Exemplo n.º 6
0
  def testColocateWithBeforeCond(self):
    """Check that colocation scopes entered before cond_v2 reach its branches.

    Each branch function asserts the colocation groups of the constant it
    creates: one group under a single `colocate_with` scope, two groups when
    the scopes are nested.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):

        a = constant_op.constant([2.0], name="a")
        b = constant_op.constant([2.0], name="b")

        def fn():
          c = constant_op.constant(3.0)
          self.assertEqual([b"loc:@a"], c.op.colocation_groups())
          return c

        with ops.colocate_with(a.op):
          # assertEqual rather than the deprecated assertEquals alias, which
          # no longer exists in unittest as of Python 3.12.
          self.assertEqual(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)

        def fn2():
          c = constant_op.constant(3.0)
          self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
          return c

        with ops.colocate_with(a.op):
          with ops.colocate_with(b.op):
            self.assertEqual(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
Exemplo n.º 7
0
  def testCollectionIntValueAccessInCond(self):
    """Read values from graph collections inside of cond_v2.

    Two Python ints are stored in the "x" and "y" collections; the branch
    function reads them back, wraps them in constants and adds them.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):
        x = 2
        y = 5
        ops.add_to_collection("x", x)
        ops.add_to_collection("y", y)

        def fn():
          x_const = constant_op.constant(ops.get_collection("x")[0])
          y_const = constant_op.constant(ops.get_collection("y")[0])
          return math_ops.add(x_const, y_const)

        cnd = cond_v2.cond_v2(True, fn, fn)
        # assertEqual rather than the deprecated assertEquals alias, which
        # no longer exists in unittest as of Python 3.12.
        self.assertEqual(cnd[0].eval(), 7)
Exemplo n.º 8
0
  def testCollectionTensorValueAccessInCond(self):
    """Read tensors from collections inside of cond_v2 & use them.

    Two constants are stored in the "x" and "y" collections; the branch
    function reads them back and adds them. The predicate is `x < y`.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        ops.add_to_collection("x", x)
        ops.add_to_collection("y", y)

        def fn():
          x_read = ops.get_collection("x")[0]
          y_read = ops.get_collection("y")[0]
          return math_ops.add(x_read, y_read)

        cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
        # assertEqual rather than the deprecated assertEquals alias, which
        # no longer exists in unittest as of Python 3.12.
        self.assertEqual(cnd[0].eval(), 7)
Exemplo n.º 9
0
  def _testCond(self, true_fn, false_fn, train_vals):
    """Assert that cond_v2 agrees with control_flow_ops.cond.

    Both conds are built from the same branch functions and one shared bool
    placeholder. Their outputs, and their gradients with respect to
    `train_vals`, are compared with the predicate fed as True and as False.

    Args:
      true_fn: Branch function passed to both cond implementations.
      false_fn: Branch function passed to both cond implementations.
      train_vals: Passed to gradients_impl.gradients as the values to
        differentiate with respect to.
    """
    pred = array_ops.placeholder(dtypes.bool, name="pred")

    expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
    actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

    expected_grad = gradients_impl.gradients(expected, train_vals)
    actual_grad = gradients_impl.gradients(actual, train_vals)

    fetches = (expected, actual, expected_grad, actual_grad)
    with self.test_session() as sess:
      # Same comparison for each branch, true branch first.
      for pred_value in (True, False):
        want, got, want_grad, got_grad = sess.run(fetches, {pred: pred_value})
        self.assertEqual(want, got)
        self.assertEqual(want_grad, got_grad)
Exemplo n.º 10
0
  def _testCond(self, true_fn, false_fn, train_vals):
    """Assert that cond_v2 agrees with control_flow_ops.cond.

    Builds one cond of each kind from the same branch functions and a shared
    bool placeholder, then checks that outputs and gradients with respect to
    `train_vals` match when the predicate is True and when it is False.

    Args:
      true_fn: Branch function passed to both cond implementations.
      false_fn: Branch function passed to both cond implementations.
      train_vals: Passed to gradients_impl.gradients as the values to
        differentiate with respect to.
    """
    pred = array_ops.placeholder(dtypes.bool, name="pred")

    reference = control_flow_ops.cond(
        pred, true_fn, false_fn, name="expected")
    candidate = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

    reference_grad = gradients_impl.gradients(reference, train_vals)
    candidate_grad = gradients_impl.gradients(candidate, train_vals)

    with self.test_session() as sess:

      def check_branch(pred_value):
        # One session call evaluates both conds and both gradients.
        ref_val, cand_val, ref_grad_val, cand_grad_val = sess.run(
            (reference, candidate, reference_grad, candidate_grad),
            feed_dict={pred: pred_value})
        self.assertEqual(ref_val, cand_val)
        self.assertEqual(ref_grad_val, cand_grad_val)

      check_branch(True)
      check_branch(False)
Exemplo n.º 11
0
  def testGradientOfDeserializedCond(self):
    """Take gradients of a cond_v2 restored from an exported MetaGraph.

    A cond is built in one graph and exported together with the "x", "pred"
    and "cond" collections. The MetaGraph is then imported into a fresh
    graph, where first- and second-order gradients w.r.t. x are checked.
    """
    with ops.Graph().as_default():
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      x = constant_op.constant(3.0, name="x")
      ops.add_to_collection("x", x)
      ops.add_to_collection("pred", pred)

      def true_fn():
        return math_ops.pow(x, 3)

      def false_fn():
        return x

      for output in cond_v2.cond_v2(pred, true_fn, false_fn, name="cond"):
        ops.add_to_collection("cond", output)
      meta_graph = saver.export_meta_graph()

    with ops.Graph().as_default() as g:
      saver.import_meta_graph(meta_graph)
      # From here on the names refer to what was read back from collections.
      x = ops.get_collection("x")[0]
      pred = ops.get_collection("pred")[0]
      restored = ops.get_collection("cond")
      first_order = gradients_impl.gradients(restored, [x], name="cond_grad")
      second_order = gradients_impl.gradients(
          first_order, [x], name="cond_grad_grad")

      on_true = {pred: True}
      on_false = {pred: False}
      with self.test_session(graph=g) as sess:
        # First derivative: d[x^3]/dx = 3x^2 and d[x]/dx = 1.
        self.assertEqual(sess.run(first_order, feed_dict=on_true), [27.0])
        self.assertEqual(sess.run(first_order, feed_dict=on_false), [1.0])

        # Second derivative: d2[x^3]/dx2 = 6x and d2[x]/dx2 = 0.
        self.assertEqual(sess.run(second_order, feed_dict=on_true), [18.0])
        self.assertEqual(sess.run(second_order, feed_dict=on_false), [0.0])
Exemplo n.º 12
0
    def testGradientOfDeserializedCond(self):
        """Differentiate a cond_v2 after an export/import round trip.

        The cond and its inputs are registered in graph collections,
        exported as a MetaGraph, and imported into a new graph. Gradients
        are then built against the tensors read back from the collections.
        """
        with ops.Graph().as_default():
            pred = array_ops.placeholder(dtypes.bool, name="pred")
            x = constant_op.constant(3.0, name="x")
            ops.add_to_collection("x", x)

            def true_fn():
                return math_ops.pow(x, 3)

            def false_fn():
                return x

            ops.add_to_collection("pred", pred)
            outputs = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
            for tensor in outputs:
                ops.add_to_collection("cond", tensor)
            meta_graph = saver.export_meta_graph()

        with ops.Graph().as_default() as g:
            saver.import_meta_graph(meta_graph)
            x = ops.get_collection("x")[0]
            pred = ops.get_collection("pred")[0]
            cond = ops.get_collection("cond")
            grad = gradients_impl.gradients(cond, [x], name="cond_grad")
            grad_grad = gradients_impl.gradients(
                grad, [x], name="cond_grad_grad")

            # (fetch, value fed for pred, expected), in evaluation order.
            checks = (
                (grad, True, [27.0]),  # d[x^3]/dx = 3x^2
                (grad, False, [1.0]),  # d[x]/dx = 1
                (grad_grad, True, [18.0]),  # d2[x^3]/dx2 = 6x
                (grad_grad, False, [0.0]),  # d2[x]/dx2 = 0
            )
            with self.test_session(graph=g) as sess:
                for fetch, pred_value, want in checks:
                    got = sess.run(fetch, {pred: pred_value})
                    self.assertEqual(got, want)
Exemplo n.º 13
0
  def testCollectionIntValueWriteInCond(self):
    """Make sure Int writes to collections work inside of cond_v2.

    The true branch adds the int 7 to the "z" collection while it is being
    built; the collection is read back after the cond has been evaluated.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):
        x = constant_op.constant(2)
        y = constant_op.constant(5)

        def true_fn():
          z = math_ops.add(x, y)
          ops.add_to_collection("z", 7)
          return math_ops.mul(x, z)

        def false_fn():
          z = math_ops.add(x, y)
          return math_ops.mul(x, z)

        cnd = cond_v2.cond_v2(
            True, true_fn,
            false_fn)
        # assertEqual rather than the deprecated assertEquals alias, which
        # no longer exists in unittest as of Python 3.12.
        # x * (x + y) = 2 * 7 = 14.
        self.assertEqual(cnd[0].eval(), 14)

        read_z_collection = ops.get_collection("z")
        self.assertEqual(read_z_collection, [7])
Exemplo n.º 14
0
  def testContainer(self):
    """Set containers outside & inside of cond_v2.

    Make sure the containers are set correctly for both variable creation
    (tested by variables.Variable) and for stateful ops (tested by FIFOQueue)
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):

        v0 = variables.Variable([0])
        q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        def container(node):
          # Read the "container" attr of the op that produced `node`.
          return node.op.get_attr("container")

        # Created outside of any ops.container scope: empty container.
        self.assertEqual(compat.as_bytes(""), container(v0))
        self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

        def true_fn():
          # When this branch is created in cond below,
          # the container should begin with 'l1'
          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          with ops.container("l2t"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          # Created after the inner scope has exited.
          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          self.assertEqual(compat.as_bytes("l1"), container(v1))
          self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
          self.assertEqual(compat.as_bytes("l2t"), container(v2))
          self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
          self.assertEqual(compat.as_bytes("l1"), container(v3))
          self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

          return constant_op.constant(2.0)

        def false_fn():
          # When this branch is created in cond below,
          # the container should begin with 'l1'
          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          with ops.container("l2f"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          # Created after the inner scope has exited.
          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          self.assertEqual(compat.as_bytes("l1"), container(v1))
          self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
          self.assertEqual(compat.as_bytes("l2f"), container(v2))
          self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
          self.assertEqual(compat.as_bytes("l1"), container(v3))
          self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

          return constant_op.constant(6.0)

        with ops.container("l1"):
          # assertEqual rather than the deprecated assertEquals alias, which
          # no longer exists in unittest as of Python 3.12.
          cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
          self.assertEqual(cnd_true[0].eval(), 2)

          cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
          self.assertEqual(cnd_false[0].eval(), 6)

          # Created after the conds, still inside the "l1" scope.
          v4 = variables.Variable([3])
          q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
        # Created after the "l1" scope has exited.
        v5 = variables.Variable([4])
        q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      self.assertEqual(compat.as_bytes("l1"), container(v4))
      self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
      self.assertEqual(compat.as_bytes(""), container(v5))
      self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))