Example #1
0
    def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
        """NodeStepper should accept nested list, tuple and dict fetches.

        Every container kind is tried twice: once holding the fixture
        attributes (self.e, self.f, self.z) and once holding the names
        "e:0", "f:0" and "z:0". For each variant the sorted nodes, the
        closure elements and the finalized result are checked.
        """
        fetch_structures = (
            [self.e, [self.f, self.z]],
            (self.e, (self.f, self.z)),
            {"e": self.e, "fz": {"f": self.f, "z": self.z}},
            ["e:0", ["f:0", "z:0"]],
            ("e:0", ("f:0", "z:0")),
            {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}},
        )

        # (earlier, later) node pairs: "earlier" must come first in the
        # topological order reported by the stepper.
        ordered_pairs = (
            ("x", "x/read"),
            ("x", "y"),
            ("x", "z"),
            ("y", "z"),
            ("a", "a/read"),
            ("b", "b/read"),
            ("a", "c"),
            ("b", "c"),
            ("a", "d"),
            ("d", "e"),
            ("c", "e"),
            ("b", "f"),
            ("f_y", "f"),
        )

        for fetches in fetch_structures:
            stepper = NodeStepper(self.sess, fetches)

            node_order = stepper.sorted_nodes()
            self.assertEqual(13, len(node_order))

            # Check the topological order of the sorted nodes.
            for earlier, later in ordered_pairs:
                self.assertLess(node_order.index(earlier),
                                node_order.index(later))

            closure = stepper.closure_elements()
            for tensor_name in ("x/read:0", "e:0", "f:0"):
                self.assertIn(tensor_name, closure)

            for node_name in ("x/read", "e", "f"):
                self.assertEqual(
                    [0], stepper.output_slots_in_closure(node_name))

            # The result mirrors the nesting of the fetches: dict fetches
            # are read back by key, list/tuple fetches by position.
            result = stepper.finalize()
            if isinstance(fetches, dict):
                self.assertAllClose(24.0, result["e"])
                self.assertAllClose(10.0, result["fz"]["f"])
                self.assertAllClose(-4.0, result["fz"]["z"])
            else:
                self.assertAllClose(24.0, result[0])
                self.assertAllClose(10.0, result[1][0])
                self.assertAllClose(-4.0, result[1][1])
  def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
    """NodeStepper should accept nested list, tuple and dict fetches.

    Every container kind is tried twice: once holding the fixture attributes
    (self.e, self.f, self.z) and once holding the names "e:0", "f:0" and
    "z:0". For each variant the sorted nodes, the closure elements and the
    finalized result are checked.
    """
    fetch_structures = (
        [self.e, [self.f, self.z]],
        (self.e, (self.f, self.z)),
        {"e": self.e, "fz": {"f": self.f, "z": self.z}},
        ["e:0", ["f:0", "z:0"]],
        ("e:0", ("f:0", "z:0")),
        {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}},
    )

    # (earlier, later) node pairs: "earlier" must come first in the
    # topological order reported by the stepper.
    ordered_pairs = (
        ("x", "x/read"),
        ("x", "y"),
        ("x", "z"),
        ("y", "z"),
        ("a", "a/read"),
        ("b", "b/read"),
        ("a", "c"),
        ("b", "c"),
        ("a", "d"),
        ("d", "e"),
        ("c", "e"),
        ("b", "f"),
        ("f_y", "f"),
    )

    for fetches in fetch_structures:
      stepper = NodeStepper(self.sess, fetches)

      node_order = stepper.sorted_nodes()
      self.assertEqual(13, len(node_order))

      # Check the topological order of the sorted nodes.
      for earlier, later in ordered_pairs:
        self.assertLess(node_order.index(earlier), node_order.index(later))

      closure = stepper.closure_elements()
      for tensor_name in ("x/read:0", "e:0", "f:0"):
        self.assertIn(tensor_name, closure)

      for node_name in ("x/read", "e", "f"):
        self.assertEqual([0], stepper.output_slots_in_closure(node_name))

      # The result mirrors the nesting of the fetches: dict fetches are read
      # back by key, list/tuple fetches by position.
      result = stepper.finalize()
      if isinstance(fetches, dict):
        self.assertAllClose(24.0, result["e"])
        self.assertAllClose(10.0, result["fz"]["f"])
        self.assertAllClose(-4.0, result["fz"]["z"])
      else:
        self.assertAllClose(24.0, result[0])
        self.assertAllClose(10.0, result[1][0])
        self.assertAllClose(-4.0, result[1][1])
  def testContToNodeWithOutputTensors(self):
    """cont() to an op should cache its output tensors if appropriate.

    Finds an op in the stepper's closure whose ":0" output is also in the
    closure, continues to that op, and verifies that a handle for the output
    tensor is cached and then used as the feed for a later cont() call on the
    tensor itself.
    """

    stepper = NodeStepper(self.sess, "optim")

    # In the transitive closure of the stepper, look for an op of which the
    # output tensor also is in the transitive closure.
    # Do not assume a specific op, e.g., "gradients/e_grad/Reshape_1",
    # because it may vary between builds.
    closure_elements = stepper.closure_elements()
    op_with_output_in_closure = None
    for element_name in closure_elements:
      if element_name + ":0" in closure_elements:
        op_with_output_in_closure = str(element_name)
        break

    # Guard before first use: if no such op exists, fail here with a clear
    # assertion instead of passing None into the stepper below.
    self.assertIsNotNone(op_with_output_in_closure)
    output_tensor = op_with_output_in_closure + ":0"

    self.assertEqual([0],
                     stepper.output_slots_in_closure(op_with_output_in_closure))

    # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
    # stepper, because it is the control input to another op. However, its
    # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
    # closure, because it is the (non-control) input of certain ops. Calling
    # cont() on the op should lead to the caching of the tensor handle for
    # the output tensor.
    stepper.cont(op_with_output_in_closure)

    self.assertEqual([output_tensor], stepper.handle_names())
    self.assertSetEqual({op_with_output_in_closure},
                        stepper.handle_node_names())

    # Do a cont() call that uses the cached tensor of
    # "gradients/?_grad/Reshape_1:0".
    stepper.cont(output_tensor)
    self.assertEqual({
        output_tensor: NodeStepper.FEED_TYPE_HANDLE
    }, stepper.last_feed_types())
Example #4
0
  def testContToNodeWithOutputTensors(self):
    """cont() to an op should cache its output tensors if appropriate.

    Finds an op in the stepper's closure whose ":0" output is also in the
    closure, continues to that op, and verifies that a handle for the output
    tensor is cached and then used as the feed for a later cont() call on the
    tensor itself.
    """

    stepper = NodeStepper(self.sess, "optim")

    # In the transitive closure of the stepper, look for an op of which the
    # output tensor also is in the transitive closure.
    # Do not assume a specific op, e.g., "gradients/e_grad/Reshape_1",
    # because it may vary between builds.
    closure_elements = stepper.closure_elements()
    op_with_output_in_closure = None
    for element_name in closure_elements:
      if element_name + ":0" in closure_elements:
        op_with_output_in_closure = str(element_name)
        break

    # Guard before first use: if no such op exists, fail here with a clear
    # assertion instead of passing None into the stepper below.
    self.assertIsNotNone(op_with_output_in_closure)
    output_tensor = op_with_output_in_closure + ":0"

    self.assertEqual([0],
                     stepper.output_slots_in_closure(op_with_output_in_closure))

    # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
    # stepper, because it is the control input to another op. However, its
    # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
    # closure, because it is the (non-control) input of certain ops. Calling
    # cont() on the op should lead to the caching of the tensor handle for
    # the output tensor.
    stepper.cont(op_with_output_in_closure)

    self.assertEqual([output_tensor], stepper.handle_names())
    self.assertSetEqual({op_with_output_in_closure},
                        stepper.handle_node_names())

    # Do a cont() call that uses the cached tensor of
    # "gradients/?_grad/Reshape_1:0".
    stepper.cont(output_tensor)
    self.assertEqual({
        output_tensor: NodeStepper.FEED_TYPE_HANDLE
    }, stepper.last_feed_types())