def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self):
  """cont() to a placeholder works when feed_dict is keyed by tensor names.

  The feed_dict handed to the NodeStepper uses the placeholders' string names
  (self.ph0.name, self.ph1.name) as keys instead of Tensor objects. Continuing
  to each placeholder must return the client-fed value and report that only
  that placeholder was fed by the client.
  """
  feed_for_ph0 = [[1.0, 2.0], [-3.0, 5.0]]
  feed_for_ph1 = [[-1.0], [0.5]]
  stepper = NodeStepper(
      self.sess,
      self.y,
      feed_dict={
          self.ph0.name: feed_for_ph0,
          self.ph1.name: feed_for_ph1,
      })

  def check_cont(target, expected_value, fed_tensor_name):
    # cont() must hand back the fed value, and the last feed types must
    # contain exactly one client feed for the given tensor name.
    self.assertAllClose(expected_value, stepper.cont(target))
    self.assertEqual({fed_tensor_name: NodeStepper.FEED_TYPE_CLIENT},
                     stepper.last_feed_types())

  check_cont(self.ph0, feed_for_ph0, self.ph0.name)
  check_cont(self.ph1, feed_for_ph1, self.ph1.name)

  # Continuing to the graph node named "ph0" (rather than its tensor) must
  # behave the same way.
  check_cont(self.sess.graph.as_graph_element("ph0"), feed_for_ph0,
             self.ph0.name)

  self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self):
  """Placeholders fed via tensor-name keys can be reached with cont().

  Feeds are keyed by string tensor names rather than Tensor objects. Each
  cont() to a placeholder (as a tensor or as its graph node) should yield the
  client-supplied value and mark that placeholder as a client feed.
  """
  ph0_value = [[1.0, 2.0], [-3.0, 5.0]]
  ph1_value = [[-1.0], [0.5]]
  name_keyed_feeds = {
      self.ph0.name: ph0_value,
      self.ph1.name: ph1_value,
  }
  stepper = NodeStepper(self.sess, self.y, feed_dict=name_keyed_feeds)
  client_feed = NodeStepper.FEED_TYPE_CLIENT

  # Step to the first placeholder tensor.
  ph0_output = stepper.cont(self.ph0)
  self.assertAllClose(ph0_value, ph0_output)
  self.assertEqual({self.ph0.name: client_feed}, stepper.last_feed_types())

  # Step to the second placeholder tensor.
  ph1_output = stepper.cont(self.ph1)
  self.assertAllClose(ph1_value, ph1_output)
  self.assertEqual({self.ph1.name: client_feed}, stepper.last_feed_types())

  # Step to the graph node named "ph0" instead of its tensor.
  ph0_graph_node = self.sess.graph.as_graph_element("ph0")
  ph0_node_output = stepper.cont(ph0_graph_node)
  self.assertAllClose(ph0_value, ph0_node_output)
  self.assertEqual({self.ph0.name: client_feed}, stepper.last_feed_types())

  expected_final = [[-1.0], [6.0]]
  self.assertAllClose(expected_final, stepper.finalize())
def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
  """NodeStepper accepts nested list, tuple or dict fetch structures.

  Each structure is exercised once with Tensor objects and once with the
  equivalent tensor-name strings. For every variant the stepper must sort
  13 nodes in a valid topological order, expose the expected closure
  elements and output slots, and finalize() must return values laid out
  like the fetch structure.
  """
  fetch_structures = (
      [self.e, [self.f, self.z]],
      (self.e, (self.f, self.z)),
      {"e": self.e, "fz": {"f": self.f, "z": self.z}},
      ["e:0", ["f:0", "z:0"]],
      ("e:0", ("f:0", "z:0")),
      {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}},
  )
  # (upstream, downstream) node-name pairs: upstream must be sorted first.
  ordering_constraints = (
      ("x", "x/read"),
      ("x", "y"),
      ("x", "z"),
      ("y", "z"),
      ("a", "a/read"),
      ("b", "b/read"),
      ("a", "c"),
      ("b", "c"),
      ("a", "d"),
      ("d", "e"),
      ("c", "e"),
      ("b", "f"),
      ("f_y", "f"),
  )

  for fetches in fetch_structures:
    stepper = NodeStepper(self.sess, fetches)
    sorted_nodes = stepper.sorted_nodes()
    self.assertEqual(13, len(sorted_nodes))

    # Check the topological order of the sorted nodes.
    for upstream, downstream in ordering_constraints:
      self.assertLess(
          sorted_nodes.index(upstream), sorted_nodes.index(downstream))

    closure_elements = stepper.closure_elements()
    for tensor_name in ("x/read:0", "e:0", "f:0"):
      self.assertIn(tensor_name, closure_elements)
    for node_name in ("x/read", "e", "f"):
      self.assertEqual([0], stepper.output_slots_in_closure(node_name))

    result = stepper.finalize()
    # The layout of the result mirrors the fetches: dict fetches are read by
    # key, list/tuple fetches by position.
    if isinstance(fetches, dict):
      self.assertAllClose(24.0, result["e"])
      self.assertAllClose(10.0, result["fz"]["f"])
      self.assertAllClose(-4.0, result["fz"]["z"])
    else:
      self.assertAllClose(24.0, result[0])
      self.assertAllClose(10.0, result[1][0])
      self.assertAllClose(-4.0, result[1][1])
def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
  """NodeStepper handles list, tuple and dict fetches (tensors or names).

  For every fetch structure the stepper must produce 13 topologically sorted
  nodes, report the expected closure elements and output slots, and return
  finalize() results with the same nesting as the fetches.
  """
  tensor_fetches = (
      [self.e, [self.f, self.z]],
      (self.e, (self.f, self.z)),
      {"e": self.e, "fz": {"f": self.f, "z": self.z}},
  )
  name_fetches = (
      ["e:0", ["f:0", "z:0"]],
      ("e:0", ("f:0", "z:0")),
      {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}},
  )
  # Each pair says the first node must precede the second in sorted order.
  must_precede = (
      ("x", "x/read"),
      ("x", "y"),
      ("x", "z"),
      ("y", "z"),
      ("a", "a/read"),
      ("b", "b/read"),
      ("a", "c"),
      ("b", "c"),
      ("a", "d"),
      ("d", "e"),
      ("c", "e"),
      ("b", "f"),
      ("f_y", "f"),
  )

  for fetches in tensor_fetches + name_fetches:
    stepper = NodeStepper(self.sess, fetches)
    sorted_nodes = stepper.sorted_nodes()
    self.assertEqual(13, len(sorted_nodes))

    # Check the topological order of the sorted nodes.
    for earlier, later in must_precede:
      self.assertLess(sorted_nodes.index(earlier), sorted_nodes.index(later))

    closure_elements = stepper.closure_elements()
    for element in ("x/read:0", "e:0", "f:0"):
      self.assertIn(element, closure_elements)
    for node in ("x/read", "e", "f"):
      self.assertEqual([0], stepper.output_slots_in_closure(node))

    result = stepper.finalize()
    # Pull e, f and z out of the result using the same nesting as the fetches.
    if isinstance(fetches, dict):
      e_val = result["e"]
      f_val = result["fz"]["f"]
      z_val = result["fz"]["z"]
    else:
      e_val = result[0]
      f_val = result[1][0]
      z_val = result[1][1]
    self.assertAllClose(24.0, e_val)
    self.assertAllClose(10.0, f_val)
    self.assertAllClose(-4.0, z_val)
def testFinalizeWithPreviousOverrides(self):
  """finalize() ignores a tensor override that an earlier cont() honored."""
  stepper = NodeStepper(self.sess, self.e)

  overridden_tensor = "a/read:0"
  stepper.override_tensor(overridden_tensor, 20.0)
  self.assertEqual([overridden_tensor], stepper.override_names())

  # cont() should compute e from the overriding value and report the feed as
  # an override.
  self.assertAllClose(24000.0, stepper.cont("e:0"))
  self.assertEqual({overridden_tensor: NodeStepper.FEED_TYPE_OVERRIDE},
                   stepper.last_feed_types())

  # finalize() should disregard the override and yield the original value.
  self.assertAllClose(24.0, stepper.finalize())
def testFinalizeWithPreviousOverrides(self):
  """An override used by cont() must not leak into finalize()."""
  stepper = NodeStepper(self.sess, self.e)
  stepper.override_tensor("a/read:0", 20.0)

  registered_overrides = stepper.override_names()
  self.assertEqual(["a/read:0"], registered_overrides)

  # With the override active, cont() reflects the overriding value.
  e_with_override = stepper.cont("e:0")
  self.assertAllClose(24000.0, e_with_override)
  expected_feed_types = {"a/read:0": NodeStepper.FEED_TYPE_OVERRIDE}
  self.assertEqual(expected_feed_types, stepper.last_feed_types())

  # finalize() runs without the override.
  final_value = stepper.finalize()
  self.assertAllClose(24.0, final_value)
def testFinalize(self):
  """finalize() restores variables and then runs the original fetch."""
  stepper = NodeStepper(self.sess, "optim")

  # Step to update_b (asking for variable values to be restored) before
  # finalizing.
  stepper.cont("optim/update_b/ApplyGradientDescent",
               restore_variable_values=True)

  # The "optim" fetch produces no value.
  self.assertIsNone(stepper.finalize())

  # Values after finalize() must match a run in which cont() to update_b never
  # happened.
  for target, expected in ((self.a, 0.84), (self.b, 1.84), (self.c, 3.96)):
    self.assertAllClose(expected, self.sess.run(target))
def testFinalize(self):
  """Check that finalize() undoes earlier stepping and runs the fetch."""
  stepper = NodeStepper(self.sess, "optim")

  # Run the update_b node first, with variable-value restoration requested.
  update_b = "optim/update_b/ApplyGradientDescent"
  stepper.cont(update_b, restore_variable_values=True)

  final_result = stepper.finalize()
  self.assertIsNone(final_result)

  # Expected post-finalize values, identical to a run without the cont() call
  # on update_b.
  a_value = self.sess.run(self.a)
  self.assertAllClose(0.84, a_value)
  b_value = self.sess.run(self.b)
  self.assertAllClose(1.84, b_value)
  c_value = self.sess.run(self.c)
  self.assertAllClose(3.96, c_value)