def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self):
    """cont() to a placeholder fed through tensor-name keys returns the feed.

    The stepper's feed_dict is keyed by the placeholders' `.name` strings
    rather than by tensor objects. Each cont() to a placeholder must return
    the client-supplied value and report exactly one client feed for it.
    """
    ph0_value = [[1.0, 2.0], [-3.0, 5.0]]
    ph1_value = [[-1.0], [0.5]]
    node_stepper = NodeStepper(
        self.sess,
        self.y,
        feed_dict={self.ph0.name: ph0_value, self.ph1.name: ph1_value})

    def check_cont(target, expected_value, fed_name):
        # cont() to `target` should yield the fed value, and the only feed
        # recorded for that step should be a client feed of `fed_name`.
        self.assertAllClose(expected_value, node_stepper.cont(target))
        expected_types = {fed_name: NodeStepper.FEED_TYPE_CLIENT}
        self.assertEqual(expected_types, node_stepper.last_feed_types())

    check_cont(self.ph0, ph0_value, self.ph0.name)
    check_cont(self.ph1, ph1_value, self.ph1.name)

    # Continuing to the graph element looked up by the name "ph0" is
    # expected to behave like continuing to self.ph0.
    check_cont(self.sess.graph.as_graph_element("ph0"), ph0_value,
               self.ph0.name)

    self.assertAllClose([[-1.0], [6.0]], node_stepper.finalize())
Example #2
0
    def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self):
        """cont() to a placeholder fed through tensor-name keys returns the feed.

        The stepper's feed_dict is keyed by the placeholders' `.name` strings
        rather than by tensor objects. Each cont() to a placeholder must
        return the client-supplied value and report exactly one client feed
        for it.
        """
        ph0_value = [[1.0, 2.0], [-3.0, 5.0]]
        ph1_value = [[-1.0], [0.5]]
        node_stepper = NodeStepper(
            self.sess,
            self.y,
            feed_dict={self.ph0.name: ph0_value, self.ph1.name: ph1_value})

        def check_cont(target, expected_value, fed_name):
            # cont() to `target` should yield the fed value, and the only
            # feed recorded for that step should be a client feed of
            # `fed_name`.
            self.assertAllClose(expected_value, node_stepper.cont(target))
            expected_types = {fed_name: NodeStepper.FEED_TYPE_CLIENT}
            self.assertEqual(expected_types, node_stepper.last_feed_types())

        check_cont(self.ph0, ph0_value, self.ph0.name)
        check_cont(self.ph1, ph1_value, self.ph1.name)

        # Continuing to the graph element looked up by the name "ph0" is
        # expected to behave like continuing to self.ph0.
        check_cont(self.sess.graph.as_graph_element("ph0"), ph0_value,
                   self.ph0.name)

        self.assertAllClose([[-1.0], [6.0]], node_stepper.finalize())
Example #3
0
    def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
        """NodeStepper accepts nested list, tuple and dict fetch structures.

        Each structure is tried once with tensor objects and once with
        tensor-name strings. For every variant the test checks the
        topological order of sorted_nodes(), the closure elements, and that
        finalize() returns values nested the same way as the fetches.
        """
        # Table of fetch structures. This replaces an index-dispatched
        # if/elif chain, so each case is plain data and `fetches` can never
        # be left unbound.
        all_fetches = [
            [self.e, [self.f, self.z]],
            (self.e, (self.f, self.z)),
            {"e": self.e, "fz": {"f": self.f, "z": self.z}},
            ["e:0", ["f:0", "z:0"]],
            ("e:0", ("f:0", "z:0")),
            {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}},
        ]

        # (earlier, later) node-name pairs that must respect topological
        # order in sorted_nodes().
        ordered_pairs = [
            ("x", "x/read"),
            ("x", "y"),
            ("x", "z"),
            ("y", "z"),
            ("a", "a/read"),
            ("b", "b/read"),
            ("a", "c"),
            ("b", "c"),
            ("a", "d"),
            ("d", "e"),
            ("c", "e"),
            ("b", "f"),
            ("f_y", "f"),
        ]

        for fetches in all_fetches:
            stepper = NodeStepper(self.sess, fetches)

            sorted_nodes = stepper.sorted_nodes()
            self.assertEqual(13, len(sorted_nodes))

            # Check the topological order of the sorted nodes.
            for earlier, later in ordered_pairs:
                self.assertLess(sorted_nodes.index(earlier),
                                sorted_nodes.index(later),
                                "%s should precede %s" % (earlier, later))

            closure_elements = stepper.closure_elements()
            for element in ("x/read:0", "e:0", "f:0"):
                self.assertIn(element, closure_elements)

            for node_name in ("x/read", "e", "f"):
                self.assertEqual([0],
                                 stepper.output_slots_in_closure(node_name))

            # finalize() is expected to mirror the fetch structure: dict
            # fetches are indexed by key, list/tuple fetches by position.
            result = stepper.finalize()
            if isinstance(fetches, dict):
                self.assertAllClose(24.0, result["e"])
                self.assertAllClose(10.0, result["fz"]["f"])
                self.assertAllClose(-4.0, result["fz"]["z"])
            else:
                self.assertAllClose(24.0, result[0])
                self.assertAllClose(10.0, result[1][0])
                self.assertAllClose(-4.0, result[1][1])
  def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
    """NodeStepper accepts nested list, tuple and dict fetch structures.

    Each structure is tried once with tensor objects and once with
    tensor-name strings. For every variant the test checks the topological
    order of sorted_nodes(), the closure elements, and that finalize()
    returns values nested the same way as the fetches.
    """
    # Table of fetch structures. This replaces an index-dispatched if/elif
    # chain, so each case is plain data and `fetches` can never be left
    # unbound.
    all_fetches = [
        [self.e, [self.f, self.z]],
        (self.e, (self.f, self.z)),
        {"e": self.e, "fz": {"f": self.f, "z": self.z}},
        ["e:0", ["f:0", "z:0"]],
        ("e:0", ("f:0", "z:0")),
        {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}},
    ]

    # (earlier, later) node-name pairs that must respect topological order
    # in sorted_nodes().
    ordered_pairs = [
        ("x", "x/read"),
        ("x", "y"),
        ("x", "z"),
        ("y", "z"),
        ("a", "a/read"),
        ("b", "b/read"),
        ("a", "c"),
        ("b", "c"),
        ("a", "d"),
        ("d", "e"),
        ("c", "e"),
        ("b", "f"),
        ("f_y", "f"),
    ]

    for fetches in all_fetches:
      stepper = NodeStepper(self.sess, fetches)

      sorted_nodes = stepper.sorted_nodes()
      self.assertEqual(13, len(sorted_nodes))

      # Check the topological order of the sorted nodes.
      for earlier, later in ordered_pairs:
        self.assertLess(sorted_nodes.index(earlier),
                        sorted_nodes.index(later),
                        "%s should precede %s" % (earlier, later))

      closure_elements = stepper.closure_elements()
      for element in ("x/read:0", "e:0", "f:0"):
        self.assertIn(element, closure_elements)

      for node_name in ("x/read", "e", "f"):
        self.assertEqual([0], stepper.output_slots_in_closure(node_name))

      # finalize() is expected to mirror the fetch structure: dict fetches
      # are indexed by key, list/tuple fetches by position.
      result = stepper.finalize()
      if isinstance(fetches, dict):
        self.assertAllClose(24.0, result["e"])
        self.assertAllClose(10.0, result["fz"]["f"])
        self.assertAllClose(-4.0, result["fz"]["z"])
      else:
        self.assertAllClose(24.0, result[0])
        self.assertAllClose(10.0, result[1][0])
        self.assertAllClose(-4.0, result[1][1])
Example #5
0
    def testFinalizeWithPreviousOverrides(self):
        """finalize() ignores a tensor override that earlier cont() calls used."""
        node_stepper = NodeStepper(self.sess, self.e)

        overridden_name = "a/read:0"
        node_stepper.override_tensor(overridden_name, 20.0)
        self.assertEqual([overridden_name], node_stepper.override_names())

        # While the override is in place, cont() reflects the overriding
        # value and reports the tensor as an override feed ...
        overridden_e = node_stepper.cont("e:0")
        self.assertAllClose(24000.0, overridden_e)
        expected_feed_types = {
            overridden_name: NodeStepper.FEED_TYPE_OVERRIDE,
        }
        self.assertEqual(expected_feed_types, node_stepper.last_feed_types())

        # ... whereas finalize() should ignore the overriding value.
        final_e = node_stepper.finalize()
        self.assertAllClose(24.0, final_e)
Example #6
0
  def testFinalizeWithPreviousOverrides(self):
    """finalize() ignores a tensor override that earlier cont() calls used."""
    node_stepper = NodeStepper(self.sess, self.e)

    overridden_name = "a/read:0"
    node_stepper.override_tensor(overridden_name, 20.0)
    self.assertEqual([overridden_name], node_stepper.override_names())

    # While the override is in place, cont() reflects the overriding value
    # and reports the tensor as an override feed ...
    overridden_e = node_stepper.cont("e:0")
    self.assertAllClose(24000.0, overridden_e)
    expected_feed_types = {
        overridden_name: NodeStepper.FEED_TYPE_OVERRIDE,
    }
    self.assertEqual(expected_feed_types, node_stepper.last_feed_types())

    # ... whereas finalize() should ignore the overriding value.
    final_e = node_stepper.finalize()
    self.assertAllClose(24.0, final_e)
Example #7
0
    def testFinalize(self):
        """finalize() restores variables, then runs the original fetch."""
        node_stepper = NodeStepper(self.sess, "optim")

        # Step to update_b via cont() before finalize(), asking the stepper
        # to restore variable values.
        node_stepper.cont(
            "optim/update_b/ApplyGradientDescent",
            restore_variable_values=True)

        # Finalizing the "optim" fetch is expected to return None.
        self.assertIsNone(node_stepper.finalize())

        # Variable values must match a run in which update_b was never
        # reached through cont() beforehand.
        expectations = ((0.84, self.a), (1.84, self.b), (3.96, self.c))
        for expected_value, variable in expectations:
            self.assertAllClose(expected_value, self.sess.run(variable))
Example #8
0
  def testFinalize(self):
    """finalize() restores variables, then runs the original fetch."""
    node_stepper = NodeStepper(self.sess, "optim")

    # Step to update_b via cont() before finalize(), asking the stepper to
    # restore variable values.
    node_stepper.cont(
        "optim/update_b/ApplyGradientDescent",
        restore_variable_values=True)

    # Finalizing the "optim" fetch is expected to return None.
    self.assertIsNone(node_stepper.finalize())

    # Variable values must match a run in which update_b was never reached
    # through cont() beforehand.
    expectations = ((0.84, self.a), (1.84, self.b), (3.96, self.c))
    for expected_value, variable in expectations:
      self.assertAllClose(expected_value, self.sess.run(variable))