示例#1
0
  def testDeviceBeforeCond(self):
    """Checks that a device scope around cond_v2 governs branch placement.

    Ops built inside the branch functions are expected to carry no device
    string of their own; the placement op reports where it actually ran. The
    GPU half is skipped when no GPU is available.
    """
    with ops.Graph().as_default() as graph:
      with self.session(graph=graph):

        def cpu_branch():
          # The outer device scope is not expected on ops inside the branch.
          self.assertEqual("", constant_op.constant(3.0).op.device)
          return test_ops.device_placement_op()

        with ops.device("/device:CPU:0"):
          cpu_placement = self.evaluate(
              cond_v2.cond_v2(constant_op.constant(True), cpu_branch,
                              cpu_branch))
          self.assertIn(compat.as_bytes("CPU:0"), cpu_placement)

        def gpu_branch():
          self.assertEqual("", constant_op.constant(3.0).op.device)
          return test_ops.device_placement_op()

        if not test_util.is_gpu_available():
          self.skipTest("Test requires a GPU to check GPU device placement.")

        with ops.device("/device:GPU:0"):
          gpu_placement = self.evaluate(
              cond_v2.cond_v2(constant_op.constant(True), gpu_branch,
                              gpu_branch))
          self.assertIn(compat.as_bytes("GPU:0"), gpu_placement)
示例#2
0
  def testColocateWithBeforeCond(self):
    """Outer colocate_with scopes propagate to ops created in the branches.

    A single scope yields colocation group ["loc:@a"] inside the branch;
    nested scopes yield ["loc:@a", "loc:@b"].
    """
    with ops.Graph().as_default() as g:
      with self.session(graph=g):

        a = constant_op.constant([2.0], name="a")
        b = constant_op.constant([2.0], name="b")

        def fn():
          c = constant_op.constant(3.0)
          self.assertEqual([b"loc:@a"], c.op.colocation_groups())
          return c

        with ops.colocate_with(a.op):
          self.assertEqual(
              cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3)

        def fn2():
          c = constant_op.constant(3.0)
          self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
          return c

        with ops.colocate_with(a.op):
          with ops.colocate_with(b.op):
            self.assertEqual(
                cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
示例#3
0
  def testExternalControlDependencies(self):
    """A branch control dependency on an outer op still causes it to run.

    The predicate defaults to False, so the true branch is not selected, yet
    the test expects the `assign_add` referenced from the true branch to have
    executed, leaving the variable at 2.0.
    """
    with ops.Graph().as_default(), self.test_session():
      var = variables.Variable(1.0)
      var.initializer.run()
      increment = var.assign_add(1.0)

      def true_branch():
        with ops.control_dependencies([increment]):
          return 1.0

      def false_branch():
        return 2.0

      cond_v2.cond_v2(array_ops.placeholder_with_default(False, None),
                      true_branch, false_branch).eval()
      self.assertAllEqual(self.evaluate(var), 2.0)
示例#4
0
    def func_with_cond():
      """Builds a cond plus its gradient; checks branches have no optional ops.

      Returns:
        The gradient of the cond output with respect to `x`.
      """
      pred = constant_op.constant(True, name="pred")
      x = constant_op.constant(1.0, name="x")

      def true_fn():
        intermediate = x + 1
        return intermediate * x

      def false_fn():
        return x + 1

      output = cond_v2.cond_v2(pred, true_fn, false_fn)
      grad = gradients_impl.gradients(output, x)[0]

      def verify_no_optional_ops(if_op, attr_name):
        # Look up the branch FunctionDef by name and scan its nodes.
        fn_name = if_op.get_attr(attr_name).name
        branch = ops.get_default_graph()._get_function(fn_name)
        for node in branch.definition.node_def:
          self.assertNotIn(node.op, _OPTIONAL_OPS)

      forward_op = output.op.inputs[0].op
      backward_op = grad.op.inputs[0].op
      for if_op in (forward_op, backward_op):
        for attr_name in ("then_branch", "else_branch"):
          verify_no_optional_ops(if_op, attr_name)

      return grad
示例#5
0
  def testSecondDerivative(self):
    """First and second derivatives of cond_v2 match each branch's calculus."""
    with self.test_session() as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      x = constant_op.constant(3.0, name="x")

      def true_fn():
        return math_ops.pow(x, 3)

      def false_fn():
        return x

      cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
      first = gradients_impl.gradients(cond, [x])
      second = gradients_impl.gradients(first, [x])

      # At x = 3:  d[x^3]/dx = 3x^2 = 27,  d[x]/dx = 1,
      #            d2[x^3]/dx2 = 6x = 18,  d2[x]/dx2 = 0.
      cases = (
          (first, True, [27.0]),
          (first, False, [1.0]),
          (second, True, [18.0]),
          (second, False, [0.0]),
      )
      for tensor, pred_value, expected in cases:
        self.assertEqual(sess.run(tensor, {pred: pred_value}), expected)
示例#6
0
  def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
    """Checks cond_v2 against control_flow_ops.cond for both predicates.

    Args:
      true_fn: Branch function used when the predicate is True.
      false_fn: Branch function used when the predicate is False.
      train_vals: Tensors to differentiate the outputs with respect to.
      feed_dict: Optional extra feeds merged into every session run.
    """
    extra_feeds = feed_dict or {}
    with self.test_session(graph=ops.get_default_graph()) as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")

      expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
      actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

      expected_grad = gradients_impl.gradients(expected, train_vals)
      actual_grad = gradients_impl.gradients(actual, train_vals)
      fetches = (expected, actual, expected_grad, actual_grad)

      for pred_value in (True, False):
        feeds = {pred: pred_value}
        feeds.update(extra_feeds)
        exp_val, act_val, exp_grad_val, act_grad_val = sess.run(fetches, feeds)
        self.assertEqual(exp_val, act_val)
        self.assertEqual(exp_grad_val, act_grad_val)
示例#7
0
    def build_graph():
      """Builds a nested cond_v2 and differentiates it inside a defun.

      Returns:
        (grads, pred_outer, pred_inner): gradients w.r.t. x and y, plus the
        two boolean placeholders that select the branches.
      """
      pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
      pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
      x = constant_op.constant(1.0, name="x")
      y = constant_op.constant(2.0, name="y")

      def inner_true_fn():
        return x * y * 2.0

      def inner_false_fn():
        return x * 5.0

      def true_fn():
        return 2.0

      def false_fn():
        # The inner cond is created while the outer false branch is traced.
        return cond_v2.cond_v2(
            pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")

      outer = cond_v2.cond_v2(pred_outer, true_fn, false_fn, name="outer_cond")

      @function.defun
      def nesting_fn():
        # Gradients are deliberately computed inside a Defun.
        return gradients_impl.gradients(outer, [x, y])

      return nesting_fn(), pred_outer, pred_inner
示例#8
0
  def testColocateWithInCondGraphPartitioning(self):
    """colocate_with inside a branch splits execution across two devices."""
    with ops.Graph().as_default() as g:
      with self.test_session(
          graph=g,
          config=config_pb2.ConfigProto(device_count={"CPU": 2})
      ) as sess:

        with ops.device("/device:CPU:0"):
          a = constant_op.constant([2.0], name="a")
        with ops.device("/device:CPU:1"):
          b = constant_op.constant([2.0], name="b")

        def fn():
          with ops.colocate_with(b.op):
            c = math_ops.add(a, a, name="c")
          return c
        out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]

        run_options = config_pb2.RunOptions(output_partition_graphs=True)
        run_metadata = config_pb2.RunMetadata()
        sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)

        # We expect there to be two partitions because of the
        # colocate_with. We are only running the cond, which has a data
        # dependency on `a` but not on `b`. So, without the colocate_with
        # we would expect execution on just one device.
        # assertGreaterEqual reports the actual count when it fails.
        self.assertGreaterEqual(len(run_metadata.partition_graphs), 2)
示例#9
0
  def testDeviceBeforeCond(self):
    """An outer device scope is applied to ops created inside the branches."""
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):
        def fn():
          c = constant_op.constant(3.0)
          self.assertEqual("/device:CPU:0", c.op.device)
          return c

        with ops.device("/device:CPU:0"):
          self.assertEqual(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)

        def fn2():
          c = constant_op.constant(3.0)
          self.assertEqual("/device:GPU:0", c.op.device)
          return c

        with ops.device("/device:GPU:0"):
          self.assertEqual(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
示例#10
0
        def false_fn():
          """Outer false branch: builds and returns the nested `inner_cond`."""
          def inner_true_fn():
            return x * y * 2.0

          def inner_false_fn():
            return x * 5.0

          inner = cond_v2.cond_v2(
              pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
          return inner
示例#11
0
  def _createCond(self, name):
    """Builds a simple cond_v2 called `name`; returns the op of output 0."""
    pred = constant_op.constant(True, name="pred")
    x = constant_op.constant(1.0, name="x")

    def true_fn():
      return x

    def false_fn():
      return x + 1

    outputs = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
    first_output = outputs[0]
    return first_output.op
示例#12
0
      def fnWithCond():  # pylint: disable=invalid-name
        """Returns d(cond)/dv for a cond whose constant-True pred picks v**3."""
        with backprop.GradientTape() as tape:
          pred = constant_op.constant(True, dtype=dtypes.bool)

          def true_fn():
            return math_ops.pow(v, 3)

          def false_fn():
            return v

          cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
        grad = tape.gradient(cond, v)
        return grad
示例#13
0
  def testNoInputs(self):
    """Branches that capture no outer tensors still select the right value."""
    with self.test_session() as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")

      def true_fn():
        return constant_op.constant(1.0)

      def false_fn():
        return constant_op.constant(2.0)

      out = cond_v2.cond_v2(pred, true_fn, false_fn)

      for pred_value, expected in ((True, (1.0,)), (False, (2.0,))):
        self.assertEqual(sess.run(out, {pred: pred_value}), expected)
示例#14
0
  def _createCond(self, name):
    """Helper function for testDefaultName.

    Builds a cond_v2 called `name` and returns its underlying "If" op.
    """
    pred = constant_op.constant(True, name="pred")
    x = constant_op.constant(1.0, name="x")

    def true_fn():
      return x

    def false_fn():
      return x + 1

    output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
    # The op producing `output` takes the If op's result as its first input.
    if_op = output.op.inputs[0].op
    self.assertEqual(if_op.type, "If")
    return if_op
示例#15
0
  def testCollectionIntValueAccessInCond(self):
    """Read values from graph collections inside of cond_v2.

    Python ints added to the outer graph's collections are read while the
    branches are traced and turned into constants that sum to 7.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):
        x = 2
        y = 5
        ops.add_to_collection("x", x)
        ops.add_to_collection("y", y)
        def fn():
          x_const = constant_op.constant(ops.get_collection("x")[0])
          y_const = constant_op.constant(ops.get_collection("y")[0])
          return math_ops.add(x_const, y_const)

        cnd = cond_v2.cond_v2(True, fn, fn)
        self.assertEqual(cnd[0].eval(), 7)
示例#16
0
  def _createNestedCond(self, name):
    """Like _createCond but creates a nested cond_v2 call as well.

    Returns:
      (output, cond_op): the outer cond's output tensor and its "If" op.
    """
    pred = constant_op.constant(True, name="pred")
    x = constant_op.constant(1.0, name="x")

    def true_fn():
      # Nested cond reusing the same predicate.
      def inner_true():
        return x

      def inner_false():
        return x + 1

      return cond_v2.cond_v2(pred, inner_true, inner_false)

    def false_fn():
      return x + 2

    output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
    if_op = output.op.inputs[0].op
    self.assertEqual(if_op.type, "If")
    return output, if_op
示例#17
0
  def _createCond(self, name):
    """Creates a cond_v2 call and returns the output tensor and the cond op."""
    pred = constant_op.constant(True, name="pred")
    x = constant_op.constant(1.0, name="x")

    def true_fn():
      return x

    def false_fn():
      return x + 1

    output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
    # Step back from the output tensor to the "If" op that produced it.
    producer = output.op
    cond_op = producer.inputs[0].op
    self.assertEqual(cond_op.type, "If")
    return output, cond_op
示例#18
0
    def testCollectionIntValueAccessInCond(self):
        """Read values from graph collections inside of cond_v2.

        Python ints added to the outer graph's collections are read while the
        branches are traced and turned into constants that sum to 7.
        """
        with ops.Graph().as_default() as g:
            with self.session(graph=g):
                x = 2
                y = 5
                ops.add_to_collection("x", x)
                ops.add_to_collection("y", y)

                def fn():
                    x_const = constant_op.constant(ops.get_collection("x")[0])
                    y_const = constant_op.constant(ops.get_collection("y")[0])
                    return math_ops.add(x_const, y_const)

                cnd = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
                self.assertEqual(cnd.eval(), 7)
示例#19
0
    def testCollectionTensorValueAccessInCond(self):
        """Read tensors from collections inside of cond_v2 & use them.

        The branches read outer-graph tensors from collections, which
        requires capturing them into the branch graphs.
        """
        with ops.Graph().as_default() as g:
            with self.session(graph=g):
                x = constant_op.constant(2)
                y = constant_op.constant(5)
                ops.add_to_collection("x", x)
                ops.add_to_collection("y", y)

                def fn():
                    x_read = ops.get_collection("x")[0]
                    y_read = ops.get_collection("y")[0]
                    return math_ops.add(x_read, y_read)

                cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
                self.assertEqual(cnd.eval(), 7)
示例#20
0
    def testForwardPassRewrite(self):
        """Computing gradients rewrites the If op's outputs exactly once."""
        x = constant_op.constant(1.0, name="x")

        def true_fn():
            return x * 2.0

        def false_fn():
            return x

        output = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
        if_op = output.op.inputs[0].op
        self.assertEqual(if_op.type, "If")
        # pylint: disable=g-deprecated-assert
        self.assertEqual(len(if_op.outputs), 1)

        # The first gradient rewrites if_op to also output the 2.0
        # intermediate; the second must not rewrite it again.
        for _ in range(2):
            gradients_impl.gradients(output, x)
            self.assertEqual(len(if_op.outputs), 2)
示例#21
0
  def testDeviceInAndOutOfCond(self):
    """A device set inside a branch does not leak to ops after the cond.

    The branch pins its constant to CPU:1, while a constant created afterwards
    in the surrounding CPU:0 scope keeps CPU:0.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(
          graph=g, config=config_pb2.ConfigProto(device_count={"CPU": 2})):

        def fn2():
          with ops.device("/device:CPU:1"):
            c = constant_op.constant(3.0)
            self.assertEqual("/device:CPU:1", c.op.device)
            return c

        with ops.device("/device:CPU:0"):
          self.assertEqual(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)

          d = constant_op.constant(4.0)
          self.assertEqual("/device:CPU:0", d.op.device)
示例#22
0
    def testDeviceInAndOutOfCond(self):
        """A device scope inside a branch does not leak to outer ops."""
        with ops.Graph().as_default() as g:
            with self.test_session(graph=g):

                def fn2():
                    with ops.device("/device:GPU:0"):
                        c = constant_op.constant(3.0)
                        self.assertEqual("/device:GPU:0", c.op.device)
                        return c

                with ops.device("/device:CPU:0"):
                    self.assertEqual(
                        cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)

                    d = constant_op.constant(4.0)
                    self.assertEqual("/device:CPU:0", d.op.device)
示例#23
0
  def testCollectionTensorValueAccessInCond(self):
    """Read tensors from collections inside of cond_v2 & use them.

    The branches read outer-graph tensors from collections, which requires
    capturing them into the branch graphs.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        ops.add_to_collection("x", x)
        ops.add_to_collection("y", y)

        def fn():
          x_read = ops.get_collection("x")[0]
          y_read = ops.get_collection("y")[0]
          return math_ops.add(x_read, y_read)

        cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
        self.assertEqual(cnd[0].eval(), 7)
示例#24
0
  def testDeviceInAndOutOfCond(self):
    """A device set inside a branch does not leak to ops after the cond.

    The branch pins its constant to CPU:1, while a constant created afterwards
    in the surrounding CPU:0 scope keeps CPU:0.
    """
    with ops.Graph().as_default() as g:
      with self.session(
          graph=g, config=config_pb2.ConfigProto(device_count={"CPU": 2})):

        def fn2():
          with ops.device("/device:CPU:1"):
            c = constant_op.constant(3.0)
            self.assertEqual("/device:CPU:1", c.op.device)
            return c

        with ops.device("/device:CPU:0"):
          self.assertEqual(
              cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)

          d = constant_op.constant(4.0)
          self.assertEqual("/device:CPU:0", d.op.device)
示例#25
0
  def testForwardPassRewrite(self):
    """Taking gradients adds one intermediate output to the If op, once."""
    x = constant_op.constant(1.0, name="x")

    def double_x():
      return x * 2.0

    def identity_x():
      return x

    output = cond_v2.cond_v2(constant_op.constant(True), double_x, identity_x)
    if_op = output.op.inputs[0].op
    self.assertEqual(if_op.type, "If")
    # pylint: disable=g-deprecated-assert
    self.assertEqual(len(if_op.outputs), 1)

    # First pass: if_op is rewritten to output the 2.0 intermediate.
    # Second pass: computing the gradient again must not rewrite it again.
    for _ in range(2):
      gradients_impl.gradients(output, x)
      self.assertEqual(len(if_op.outputs), 2)
示例#26
0
  def testDoNotAccumulateConstants(self):
    """Taking gradients must not add outputs to the forward StatelessIf op.

    The true branch's only intermediate besides `x` is the constant 2.0, and
    `x` is captured from the outer graph, so nothing needs to be exposed.
    """
    x = constant_op.constant(1.0, name="x")
    output = cond_v2.cond_v2(
        constant_op.constant(True), lambda: x * 2.0, lambda: x)
    if_op = output.op.inputs[0].op
    self.assertEqual(if_op.type, "StatelessIf")
    # pylint: disable=g-deprecated-assert
    self.assertEqual(len(if_op.outputs), 1)

    gradients_impl.gradients(output, x)
    # Number of outputs does NOT change because
    # 1. `x` is captured from the outer graph (an input to the cond), so it
    #    does not need to be accumulated.
    # 2. 2.0 is a constant so it is not accumulated.
    self.assertEqual(len(if_op.outputs), 1)

    gradients_impl.gradients(output, x)
    # Computing the gradient again shouldn't rewrite if_op again.
    self.assertEqual(len(if_op.outputs), 1)
示例#27
0
  def testColocateWithInAndOutOfCond(self):
    """A colocate_with inside a branch stacks on the outer one, then unwinds.

    Inside the branch the constant is colocated with both `a` and `b`; after
    the cond, a new op in the outer scope is colocated with `a` only.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g):

        a = constant_op.constant([2.0], name="a")
        b = constant_op.constant([2.0], name="b")

        def fn2():
          with ops.colocate_with(b.op):
            c = constant_op.constant(3.0)
            self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
            return c

        with ops.colocate_with(a.op):
          self.assertEqual(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)

          d = constant_op.constant([2.0], name="d")
          self.assertEqual([b"loc:@a"], d.op.colocation_groups())
示例#28
0
      def fn():
        """Builds a nested cond and returns its gradients w.r.t. x and y."""

        def inner_true_fn():
          return x * y * 2.0

        def inner_false_fn():
          return x * 5.0

        def true_fn():
          return 2.0

        def false_fn():
          return cond_v2.cond_v2(
              pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")

        cond_outer = cond_v2.cond_v2(
            pred_outer, true_fn, false_fn, name="outer_cond")
        return gradients_impl.gradients(cond_outer, [x, y])
示例#29
0
      def fn():
        """Differentiates an outer cond whose false branch nests another."""

        def true_fn():
          return 2.0

        def false_fn():
          # Inner cond selected by `pred_inner`.
          def inner_true_fn():
            return x * y * 2.0

          def inner_false_fn():
            return x * 5.0

          inner = cond_v2.cond_v2(
              pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
          return inner

        cond_outer = cond_v2.cond_v2(
            pred_outer, true_fn, false_fn, name="outer_cond")
        grads = gradients_impl.gradients(cond_outer, [x, y])
        return grads
示例#30
0
  def _testCond(self, true_fn, false_fn, train_vals):
    """Compares cond_v2 with control_flow_ops.cond for pred True and False.

    Both the forward values and the gradients w.r.t. `train_vals` must agree.
    """
    with self.test_session() as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")

      expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
      actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

      expected_grad = gradients_impl.gradients(expected, train_vals)
      actual_grad = gradients_impl.gradients(actual, train_vals)
      fetches = (expected, actual, expected_grad, actual_grad)

      for pred_value in (True, False):
        exp_val, act_val, exp_grad_val, act_grad_val = sess.run(
            fetches, {pred: pred_value})
        self.assertEqual(exp_val, act_val)
        self.assertEqual(exp_grad_val, act_grad_val)
示例#31
0
  def testCollectionIntValueWriteInCond(self):
    """Make sure Int writes to collections work inside of cond_v2.

    The true branch adds 7 to collection "z" while it is traced; the test
    expects that write to be visible in the outer graph afterwards.
    """
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        def true_fn():
          z = math_ops.add(x, y)
          ops.add_to_collection("z", 7)
          return math_ops.mul(x, z)

        def false_fn():
          z = math_ops.add(x, y)
          return math_ops.mul(x, z)

        cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
        self.assertEqual(cnd.eval(), 14)

        read_z_collection = ops.get_collection("z")
        self.assertEqual(read_z_collection, [7])
示例#32
0
  def testCollectionIntValueWriteInCond(self):
    """Make sure Int writes to collections work inside of cond_v2.

    The true branch adds 7 to collection "z" while it is traced; the test
    expects that write to be visible in the outer graph afterwards.
    """
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        def true_fn():
          z = math_ops.add(x, y)
          ops.add_to_collection("z", 7)
          return math_ops.mul(x, z)

        def false_fn():
          z = math_ops.add(x, y)
          return math_ops.mul(x, z)

        cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
        self.assertEqual(cnd.eval(), 14)

        read_z_collection = ops.get_collection("z")
        self.assertEqual(read_z_collection, [7])
示例#33
0
    def testGradientOfDeserializedCond(self):
        """Differentiates a cond_v2 after a meta-graph export/import."""
        # Build the cond and export it; collections carry the tensors that
        # are needed again after import.
        with ops.Graph().as_default():
            pred = array_ops.placeholder(dtypes.bool, name="pred")
            x = constant_op.constant(3.0, name="x")
            ops.add_to_collection("x", x)

            def true_fn():
                return math_ops.pow(x, 3)

            def false_fn():
                return x

            ops.add_to_collection("pred", pred)
            cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
            for output in cond:
                ops.add_to_collection("cond", output)
            meta_graph = saver.export_meta_graph()

        # Re-import into a fresh graph and take gradients there.
        with ops.Graph().as_default() as g:
            with self.test_session(graph=g) as sess:
                saver.import_meta_graph(meta_graph)
                x = ops.get_collection("x")[0]
                pred = ops.get_collection("pred")[0]
                cond = ops.get_collection("cond")
                first = gradients_impl.gradients(cond, [x], name="cond_grad")
                second = gradients_impl.gradients(
                    first, [x], name="cond_grad_grad")

                # At x = 3: d[x^3]/dx = 3x^2 = 27, d[x]/dx = 1,
                #           d2[x^3]/dx2 = 6x = 18, d2[x]/dx2 = 0.
                cases = ((first, True, [27.0]), (first, False, [1.0]),
                         (second, True, [18.0]), (second, False, [0.0]))
                for tensor, pred_value, expected in cases:
                    self.assertEqual(sess.run(tensor, {pred: pred_value}),
                                     expected)
示例#34
0
  def testForwardPassRewrite(self):
    """The StatelessIf op gains a `y_plus_one` output on the first gradient."""
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(1.0, name="y")

    def true_fn():
      y_plus_one = y + 1.
      return x * y_plus_one

    def false_fn():
      return x

    output = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
    if_op = output.op.inputs[0].op
    self.assertEqual(if_op.type, "StatelessIf")
    # pylint: disable=g-deprecated-assert
    self.assertEqual(len(if_op.outputs), 1)

    # The first gradient rewrites if_op to also output `y_plus_one`; the
    # second must reuse that rewrite rather than add another output.
    for _ in range(2):
      gradients_impl.gradients(output, x)
      self.assertEqual(len(if_op.outputs), 2)
示例#35
0
  def testDeviceInCondGraphPartitioning(self):
    """A device scope inside a branch splits execution into two partitions.

    `a` is created on CPU:0 while the branch places `c` on CPU:1, so at least
    two partition graphs are expected.
    """
    with ops.Graph().as_default() as g:
      with self.test_session(
          graph=g,
          config=config_pb2.ConfigProto(device_count={"CPU": 2})
      ) as sess:

        def fn():
          with ops.device("/device:CPU:1"):
            c = math_ops.add(a, a, name="c")
          return c

        with ops.device("/device:CPU:0"):
          a = constant_op.constant([2.0], name="a")
          out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]

        run_options = config_pb2.RunOptions(output_partition_graphs=True)
        run_metadata = config_pb2.RunMetadata()
        sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)

        # assertGreaterEqual reports the actual count when it fails.
        self.assertGreaterEqual(len(run_metadata.partition_graphs), 2)
示例#36
0
  def testDeviceInCondGraphPartitioning(self):
    """A device scope inside a branch splits execution into two partitions.

    `a` is created on CPU:0 while the branch places `c` on CPU:1, so at least
    two partition graphs are expected.
    """
    with ops.Graph().as_default() as g:
      with self.session(
          graph=g,
          config=config_pb2.ConfigProto(device_count={"CPU": 2})
      ) as sess:

        def fn():
          with ops.device("/device:CPU:1"):
            c = math_ops.add(a, a, name="c")
          return c

        with ops.device("/device:CPU:0"):
          a = constant_op.constant([2.0], name="a")
          out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn)

        run_options = config_pb2.RunOptions(output_partition_graphs=True)
        run_metadata = config_pb2.RunMetadata()
        sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)

        # assertGreaterEqual reports the actual count when it fails.
        self.assertGreaterEqual(len(run_metadata.partition_graphs), 2)
示例#37
0
  def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
    """Checks cond_v2 against control_flow_ops.cond for scalar and [[x]] preds.

    Args:
      true_fn: Branch function used when the predicate is True.
      false_fn: Branch function used when the predicate is False.
      train_vals: Tensors to differentiate the outputs with respect to.
      feed_dict: Optional extra feeds merged into every session run.
    """
    extra_feeds = {} if not feed_dict else feed_dict
    with self.session(graph=ops.get_default_graph()) as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")

      # The reference cond gets a squeezed predicate (presumably it needs a
      # scalar), whereas cond_v2 receives the raw placeholder, which is fed
      # both scalars and [[x]]-shaped values below.
      expected = control_flow_ops.cond(
          array_ops.squeeze_v2(pred), true_fn, false_fn, name="expected")
      actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

      expected_grad = gradients_impl.gradients(expected, train_vals)
      actual_grad = gradients_impl.gradients(actual, train_vals)
      fetches = (expected, actual, expected_grad, actual_grad)

      for pred_value in (True, [[True]], False, [[False]]):
        run_feeds = {pred: pred_value}
        run_feeds.update(extra_feeds)
        exp_val, act_val, exp_grad_val, act_grad_val = sess.run(
            fetches, run_feeds)
        self.assertEqual(exp_val, act_val)
        self.assertEqual(exp_grad_val, act_grad_val)
示例#38
0
  def testGradientOfDeserializedCond(self):
    """Differentiates a cond_v2 after a meta-graph export/import."""
    # Build the cond and export it; collections carry the tensors that are
    # needed again after import.
    with ops.Graph().as_default():
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      x = constant_op.constant(3.0, name="x")
      ops.add_to_collection("x", x)

      def true_fn():
        return math_ops.pow(x, 3)

      def false_fn():
        return x

      ops.add_to_collection("pred", pred)
      cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
      for output in cond:
        ops.add_to_collection("cond", output)
      meta_graph = saver.export_meta_graph()

    # Re-import into a fresh graph and take gradients there.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        saver.import_meta_graph(meta_graph)
        x = ops.get_collection("x")[0]
        pred = ops.get_collection("pred")[0]
        cond = ops.get_collection("cond")
        first = gradients_impl.gradients(cond, [x], name="cond_grad")
        second = gradients_impl.gradients(first, [x], name="cond_grad_grad")

        # At x = 3:  d[x^3]/dx = 3x^2 = 27,  d[x]/dx = 1,
        #            d2[x^3]/dx2 = 6x = 18,  d2[x]/dx2 = 0.
        cases = (
            (first, True, [27.0]),
            (first, False, [1.0]),
            (second, True, [18.0]),
            (second, False, [0.0]),
        )
        for tensor, pred_value, expected in cases:
          self.assertEqual(sess.run(tensor, {pred: pred_value}), expected)
示例#39
0
        def build_graph():
            """Builds a nested cond and takes its gradient inside two defuns.

            Returns:
                (grads, pred_outer, pred_inner): gradients w.r.t. x and y,
                plus the two boolean placeholders selecting the branches.
            """
            pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
            pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
            x = constant_op.constant(1.0, name="x")
            y = constant_op.constant(2.0, name="y")

            def inner_true_fn():
                return x * y * 2.0

            def inner_false_fn():
                return x * 5.0

            def true_fn():
                return 2.0

            def false_fn():
                return cond_v2.cond_v2(pred_inner, inner_true_fn,
                                       inner_false_fn, name="inner_cond")

            cond_outer = cond_v2.cond_v2(pred_outer, true_fn, false_fn,
                                         name="outer_cond")

            # Gradients are computed inside a defun nested in another defun.
            @function.defun
            def nesting_fn():
                @function.defun
                def inner_nesting_fn():
                    return gradients_impl.gradients(cond_outer, [x, y])

                return inner_nesting_fn()

            return nesting_fn(), pred_outer, pred_inner
示例#40
0
    def fn_with_cond():
      """Builds four conds whose branches assign v1/v2; returns cond_2, cond_4.

      cond_1 and cond_3 put the assigning branch in the true slot; cond_2 and
      cond_4 put it in the false slot (cond_4's is a def_function.function).
      """

      def zero():
        return constant_op.constant(0.)

      def update_v1():
        v1.assign(v1)
        return v1

      def update_v2():
        v2.assign(v2)
        return v2

      cond_v2.cond_v2(
          constant_op.constant(True), update_v1, zero, name="cond_1")
      cond_2 = cond_v2.cond_v2(
          constant_op.constant(False), zero, update_v1, name="cond_2")
      cond_v2.cond_v2(
          constant_op.constant(True), update_v2, zero, name="cond_3")

      @def_function.function
      def cond_4_false_branch():
        v2.assign(v2)
        return v2

      cond_4 = cond_v2.cond_v2(
          constant_op.constant(False), zero, cond_4_false_branch,
          name="cond_4")
      return cond_2, cond_4
 def recursive_fn(n):
     # NOTE(review): `recursive_fn(n - 1)` is evaluated eagerly while building
     # the arguments to cond_v2, so this recursion has no base case and never
     # reaches cond_v2. It should exhaust the Python stack (RecursionError).
     # Both branch arguments also appear to need callables, but `1` is not one.
     # If the goal is to build a cond rather than to provoke a recursion error,
     # the branches presumably should be `lambda: recursive_fn(n - 1)` and
     # `lambda: 1`. Even then, both branches may be traced regardless of
     # `pred`, so the recursion could still be unbounded — TODO confirm intent.
     return cond_v2.cond_v2(n > 0, recursive_fn(n - 1), 1)
示例#42
0
 def true_fn():
     # Nested cond_v2 on the same predicate: yields x, otherwise x + 1.
     def identity():
         return x

     def plus_one():
         return x + 1

     return cond_v2.cond_v2(pred, identity, plus_one)
示例#43
0
def _cond(pred, true_fn, false_fn, name):
    """Build a conditional using the legacy or the v2 implementation.

    When `_is_old_cond()` is truthy, `control_flow_ops.cond` is used;
    otherwise `cond_v2.cond_v2`. All arguments are forwarded unchanged.
    """
    if not _is_old_cond():
        return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
    return control_flow_ops.cond(pred, true_fn, false_fn, name=name)
示例#44
0
 def _add_cond(x):
   """Wrap `x` in a cond_v2 whose predicate is a constant True.

   The true branch yields `x` unchanged; the false branch yields `x + 1`.
   """
   always_true = constant_op.constant(True, name="pred")
   return cond_v2.cond_v2(always_true, lambda: x, lambda: x + 1)
示例#45
0
 def true_fn():
   """Inner cond on `pred`: identity on the true path, `x + 1` on the false."""
   identity = lambda: x
   plus_one = lambda: x + 1
   return cond_v2.cond_v2(pred, identity, plus_one)
示例#46
0
  def testContainer(self):
    """Set containers outside & inside of cond_v2.

    Make sure the containers are set correctly for both variable creation
    (tested by variables.Variable) and for stateful ops (tested by FIFOQueue).
    """
    # Skipped pending b/113048653; nothing below currently executes.
    self.skipTest("b/113048653")
    with ops.Graph().as_default() as g:
      # self.session replaces the deprecated self.test_session.
      with self.session(graph=g):

        v0 = variables.Variable([0])
        q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        def container(node):
          """Return the `container` attr of the op that produced `node`."""
          return node.op.get_attr("container")

        # With no enclosing ops.container scope the attr is empty.
        self.assertEqual(compat.as_bytes(""), container(v0))
        self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

        def true_fn():
          # When this branch is created in cond below,
          # the container should begin with 'l1'
          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          with ops.container("l2t"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          self.assertEqual(compat.as_bytes("l1"), container(v1))
          self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
          self.assertEqual(compat.as_bytes("l2t"), container(v2))
          self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
          self.assertEqual(compat.as_bytes("l1"), container(v3))
          self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

          return constant_op.constant(2.0)

        def false_fn():
          # When this branch is created in cond below,
          # the container should begin with 'l1'
          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          with ops.container("l2f"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          self.assertEqual(compat.as_bytes("l1"), container(v1))
          self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
          self.assertEqual(compat.as_bytes("l2f"), container(v2))
          self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
          self.assertEqual(compat.as_bytes("l1"), container(v3))
          self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

          return constant_op.constant(6.0)

        with ops.container("l1"):
          # Tensor predicates, matching the other cond_v2 calls in this file.
          # assertEqual replaces the assertEquals alias (removed in Python 3.12).
          cnd_true = cond_v2.cond_v2(constant_op.constant(True),
                                     true_fn, false_fn)
          self.assertEqual(cnd_true[0].eval(), 2)

          cnd_false = cond_v2.cond_v2(constant_op.constant(False),
                                      true_fn, false_fn)
          self.assertEqual(cnd_false[0].eval(), 6)

          v4 = variables.Variable([3])
          q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
        v5 = variables.Variable([4])
        q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      # Ops created inside ops.container("l1") keep "l1"; those after it get "".
      self.assertEqual(compat.as_bytes("l1"), container(v4))
      self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
      self.assertEqual(compat.as_bytes(""), container(v5))
      self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
示例#47
0
def _cond(pred, true_fn, false_fn, name):
  """Build a conditional with whichever cond implementation is active.

  `control_flow_ops.cond` is chosen when `_is_old_cond()` is truthy, else
  `cond_v2.cond_v2`; the arguments are passed through as-is.
  """
  cond_impl = control_flow_ops.cond if _is_old_cond() else cond_v2.cond_v2
  return cond_impl(pred, true_fn, false_fn, name=name)
示例#48
0
 def _add_cond(x):
     """Return a cond_v2 over a constant-True predicate around `x`.

     The true branch returns `x`; the false branch returns `x + 1`.
     """
     def identity_branch():
         return x

     def plus_one_branch():
         return x + 1

     predicate = constant_op.constant(True, name="pred")
     return cond_v2.cond_v2(predicate, identity_branch, plus_one_branch)
示例#49
0
    def testContainer(self):
        """Set containers outside & inside of cond_v2.

        Make sure the containers are set correctly for both variable creation
        (tested by variables.Variable) and for stateful ops (tested by
        FIFOQueue).
        """
        # Skipped pending b/113048653; nothing below currently executes.
        self.skipTest("b/113048653")
        with ops.Graph().as_default() as g:
            with self.session(graph=g):

                v0 = variables.Variable([0])
                q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                def container(node):
                    """Return the `container` attr of the op producing `node`."""
                    return node.op.get_attr("container")

                # With no enclosing ops.container scope the attr is empty.
                self.assertEqual(compat.as_bytes(""), container(v0))
                self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

                def true_fn():
                    # When this branch is created in cond below,
                    # the container should begin with 'l1'
                    v1 = variables.Variable([1])
                    q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    with ops.container("l2t"):
                        v2 = variables.Variable([2])
                        q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    v3 = variables.Variable([1])
                    q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    self.assertEqual(compat.as_bytes("l1"), container(v1))
                    self.assertEqual(compat.as_bytes("l1"),
                                     container(q1.queue_ref))
                    self.assertEqual(compat.as_bytes("l2t"), container(v2))
                    self.assertEqual(compat.as_bytes("l2t"),
                                     container(q2.queue_ref))
                    self.assertEqual(compat.as_bytes("l1"), container(v3))
                    self.assertEqual(compat.as_bytes("l1"),
                                     container(q3.queue_ref))

                    return constant_op.constant(2.0)

                def false_fn():
                    # When this branch is created in cond below,
                    # the container should begin with 'l1'
                    v1 = variables.Variable([1])
                    q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    with ops.container("l2f"):
                        v2 = variables.Variable([2])
                        q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    v3 = variables.Variable([1])
                    q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    self.assertEqual(compat.as_bytes("l1"), container(v1))
                    self.assertEqual(compat.as_bytes("l1"),
                                     container(q1.queue_ref))
                    self.assertEqual(compat.as_bytes("l2f"), container(v2))
                    self.assertEqual(compat.as_bytes("l2f"),
                                     container(q2.queue_ref))
                    self.assertEqual(compat.as_bytes("l1"), container(v3))
                    self.assertEqual(compat.as_bytes("l1"),
                                     container(q3.queue_ref))

                    return constant_op.constant(6.0)

                with ops.container("l1"):
                    # assertEqual replaces the assertEquals alias, which was
                    # removed from unittest in Python 3.12.
                    cnd_true = cond_v2.cond_v2(constant_op.constant(True),
                                               true_fn, false_fn)
                    self.assertEqual(cnd_true.eval(), 2)

                    cnd_false = cond_v2.cond_v2(constant_op.constant(False),
                                                true_fn, false_fn)
                    self.assertEqual(cnd_false.eval(), 6)

                    v4 = variables.Variable([3])
                    q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
                v5 = variables.Variable([4])
                q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

            # Ops created inside ops.container("l1") keep "l1"; later ones get "".
            self.assertEqual(compat.as_bytes("l1"), container(v4))
            self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
            self.assertEqual(compat.as_bytes(""), container(v5))
            self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
示例#50
0
 def model(b):
     """Return cond_v2 over predicate `b`, selecting true_fn or false_fn."""
     branches = (true_fn, false_fn)
     return cond_v2.cond_v2(b, *branches)