def testIsFeedableShouldGiveCorrectAnswers(self): stepper = NodeStepper(self.sess, self.e) self.assertTrue(stepper.is_feedable("a/read:0")) self.assertTrue(stepper.is_feedable("b/read:0")) self.assertTrue(stepper.is_feedable("c:0")) self.assertTrue(stepper.is_feedable("d:0"))
def testRemoveNonexistentOverrideValue(self): stepper = NodeStepper(self.sess, self.e) self.assertEqual([], stepper.override_names()) with self.assertRaisesRegexp( ValueError, "No overriding value exists for tensor \"c:0\""): stepper.remove_override("c:0")
def testUsingNodesNotUsingIntermediateTensors(self): stepper = NodeStepper(self.sess, self.e) # There should be no handles before any cont() calls. self.assertEqual([], stepper.handle_names()) self.assertSetEqual(set(), stepper.handle_node_names()) # Before the cont() call, the stepper should not have access to the value # of c:0. with self.assertRaisesRegexp( ValueError, "This stepper instance does not have access to the value of tensor " "\"c:0\""): stepper.get_tensor_value("c:0") # Using the node/tensor itself, instead of the name str, should work on # cont(). result = stepper.cont(self.c) self.assertAllClose(6.0, result) self.assertEqual({}, stepper.last_feed_types()) self.assertEqual(["c:0"], stepper.handle_names()) self.assertEqual({"c"}, stepper.handle_node_names()) # After the cont() call, the stepper should have access to the value of c:0 # via a tensor handle. self.assertAllClose(6.0, stepper.get_tensor_value("c:0")) result = stepper.cont(self.e) self.assertAllClose(24.0, result) self.assertEqual({ "c:0": NodeStepper.FEED_TYPE_HANDLE }, stepper.last_feed_types())
def testOverrideValueTwice(self): with NodeStepper(self.sess, self.e) as stepper: # Override once. stepper.override_tensor("c:0", 7.0) self.assertAllClose(28.0, stepper.cont(self.e)) self.assertEqual({ "c:0": NodeStepper.FEED_TYPE_OVERRIDE }, stepper.last_feed_types()) self.assertEqual(["e:0"], stepper.handle_names()) self.assertSetEqual({"e"}, stepper.handle_node_names()) self.assertEqual(["c:0"], stepper.override_names()) # Calling cont(self.e) again. This time the cached tensor handle of e # should be used. self.assertEqual(28.0, stepper.cont(self.e)) self.assertEqual({ "e:0": NodeStepper.FEED_TYPE_HANDLE }, stepper.last_feed_types()) # Override c again. This should have invalidated the cache for e. stepper.override_tensor("c:0", 8.0) self.assertEqual([], stepper.handle_names()) self.assertEqual(set(), stepper.handle_node_names()) self.assertEqual(["c:0"], stepper.override_names()) self.assertAllClose(32.0, stepper.cont(self.e)) self.assertEqual({ "c:0": NodeStepper.FEED_TYPE_OVERRIDE, "d:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, }, stepper.last_feed_types())
def testUpdateTwiceRestoreVariable(self): with NodeStepper(self.sess, "optim") as stepper: result = stepper.cont( "optim/update_a/ApplyGradientDescent", invalidate_from_updated_variables=True, restore_variable_values=True) self.assertIsNone(result) self.assertSetEqual({"a:0"}, stepper.last_updated()) self.assertEqual({"a:0"}, stepper.dirty_variables()) result = stepper.cont( "optim/update_b/ApplyGradientDescent", invalidate_from_updated_variables=True, restore_variable_values=True) self.assertIsNone(result) # Variables a and c should have been restored and hence no longer dirty. # Variable b should have been marked as dirty. self.assertSetEqual({"b:0"}, stepper.last_updated()) self.assertEqual({"b:0"}, stepper.dirty_variables()) # The result of the update should be identitcal to as if only update_b is # run. self.assertAllClose(1.0, self.sess.run(self.a)) self.assertAllClose(1.84, self.sess.run(self.b)) self.assertAllClose(4.0, self.sess.run(self.c))
def testContToUpdateA(self): stepper = NodeStepper(self.sess, "optim") result = stepper.cont("a:0") self.assertAllClose(1.0, result) self.assertEqual({}, stepper.last_feed_types()) result = stepper.cont("optim/learning_rate:0") self.assertAllClose(0.01, result) self.assertEqual({}, stepper.last_feed_types()) # Before any cont calls on ApplyGradientDescent, there should be no "dirty" # variables. self.assertEqual(set(), stepper.dirty_variables()) # First, all the two control inputs to optim. result = stepper.cont("optim/update_a/ApplyGradientDescent") # Now variable a should have been marked as dirty due to the update # by optim/update_a/ApplyGradientDescent. self.assertEqual({"a:0"}, stepper.dirty_variables()) self.assertIsNone(result) self.assertEqual({ "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE }, stepper.last_feed_types()) # Check that Variable "a" has been updated properly, but "b", "c" and "d" # remain the same. # For backprop on Variable a: # Because f = a * b * b * c, df / da = b * b * c. # 1.0 - learning_rate * b * b * c # = 1.0 - 0.01 * 2.0 * 2.0 * 4.0 = 0.84. self.assertAllClose(0.84, self.sess.run(self.a)) self.assertAllClose(2.0, self.sess.run(self.b)) self.assertAllClose(4.0, self.sess.run(self.c))
def testOverrideAndContToSameTensor(self): with NodeStepper(self.sess, self.e) as stepper: result = stepper.cont(self.c) self.assertAllClose(6.0, result) self.assertEqual({}, stepper.last_feed_types()) self.assertEqual(["c:0"], stepper.handle_names()) self.assertSetEqual({"c"}, stepper.handle_node_names()) self.assertAllClose(6.0, stepper.cont(self.c)) # The last cont() call should use the tensor handle directly. self.assertEqual({ "c:0": NodeStepper.FEED_TYPE_HANDLE }, stepper.last_feed_types()) # Override c:0. stepper.override_tensor("c:0", 7.0) # As a result of the override, the tensor handle should have been # invalidated. self.assertEqual([], stepper.handle_names()) self.assertSetEqual(set(), stepper.handle_node_names()) result = stepper.cont(self.c) self.assertAllClose(7.0, result) self.assertEqual({ "c:0": NodeStepper.FEED_TYPE_OVERRIDE }, stepper.last_feed_types())
def testContWithPlaceholders(self): if test_util.is_gpu_available(): self.skipTest("b/123446705 this causes a segfault on GPU") with NodeStepper( self.sess, self.y, feed_dict={ self.ph0: [[1.0, 2.0], [-3.0, 5.0]], self.ph1: [[-1.0], [0.5]] }) as stepper: self.assertEqual(4, len(stepper.sorted_nodes())) self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"}, set(stepper.closure_elements())) result = stepper.cont(self.x) self.assertAllClose([[0.0], [5.5]], result) self.assertEqual({ "ph0:0": NodeStepper.FEED_TYPE_CLIENT, "ph1:0": NodeStepper.FEED_TYPE_CLIENT, }, stepper.last_feed_types()) self.assertEqual(["x:0"], stepper.handle_names()) self.assertSetEqual({"x"}, stepper.handle_node_names()) result = stepper.cont(self.y) self.assertAllClose([[-1.0], [6.0]], result) self.assertEqual({ "x:0": NodeStepper.FEED_TYPE_HANDLE, "ph1:0": NodeStepper.FEED_TYPE_CLIENT, }, stepper.last_feed_types())
def testAttemptToContToPlaceholderWithTensorFeedKeysShouldWork(self): """Continuing to a placeholder should be allowed, using client feed.""" ph0_feed = [[1.0, 2.0], [-3.0, 5.0]] ph1_feed = [[-1.0], [0.5]] stepper = NodeStepper( self.sess, self.y, feed_dict={ self.ph0: ph0_feed, self.ph1: ph1_feed, }) self.assertAllClose(ph0_feed, stepper.cont(self.ph0)) self.assertEqual({ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT }, stepper.last_feed_types()) self.assertAllClose(ph1_feed, stepper.cont(self.ph1)) self.assertEqual({ self.ph1.name: NodeStepper.FEED_TYPE_CLIENT }, stepper.last_feed_types()) ph0_node = self.sess.graph.as_graph_element("ph0") self.assertAllClose(ph0_feed, stepper.cont(ph0_node)) self.assertEqual({ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT }, stepper.last_feed_types()) self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self): ph0_feed = [[1.0, 2.0], [-3.0, 5.0]] ph1_feed = [[-1.0], [0.5]] with NodeStepper( self.sess, self.y, feed_dict={ self.ph0.name: ph0_feed, self.ph1.name: ph1_feed, }) as stepper: self.assertAllClose(ph0_feed, stepper.cont(self.ph0)) self.assertEqual({ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT }, stepper.last_feed_types()) self.assertAllClose(ph1_feed, stepper.cont(self.ph1)) self.assertEqual({ self.ph1.name: NodeStepper.FEED_TYPE_CLIENT }, stepper.last_feed_types()) ph0_node = self.sess.graph.as_graph_element("ph0") self.assertAllClose(ph0_feed, stepper.cont(ph0_node)) self.assertEqual({ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT }, stepper.last_feed_types()) self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
def testUsingNamesNotUsingIntermediateTensors(self): if test_util.is_gpu_available(): self.skipTest("b/123446705 this causes a segfault on GPU") with NodeStepper(self.sess, "e:0") as stepper: # The first cont() call should have used no feeds. result = stepper.cont("c:0") self.assertAllClose(6.0, result) self.assertItemsEqual(["a/read:0", "b/read:0"], stepper.intermediate_tensor_names()) self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) self.assertEqual({}, stepper.last_feed_types()) # The second cont() call should have used the tensor handle from the # previous cont() call. result = stepper.cont("e:0") self.assertAllClose(24.0, result) self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"], stepper.intermediate_tensor_names()) self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) self.assertAllClose(4.0, stepper.get_tensor_value("d:0")) self.assertEqual({ "c:0": NodeStepper.FEED_TYPE_HANDLE, "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, }, stepper.last_feed_types())
def testContToUpdateInvalidatesDumpedIntermediates(self): with NodeStepper(self.sess, [self.q, self.v_add]) as stepper: self.assertAllClose(400.0, stepper.cont("q:0")) self.assertItemsEqual(["v/read:0", "p:0"], stepper.intermediate_tensor_names()) self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0")) self.assertAllClose(20.0, stepper.get_tensor_value("p:0")) self.assertAllClose( 12.0, stepper.cont( self.v_add, invalidate_from_updated_variables=True)) self.assertAllClose(12.0, self.sess.run(self.v)) self.assertSetEqual({self.v.name}, stepper.last_updated()) self.assertItemsEqual(["v:0"], stepper.dirty_variables()) # Updating the value of v by calling v_add should have invalidated the # dumped intermediate tensors for v/read:0 and p:0. self.assertItemsEqual(["delta:0"], stepper.intermediate_tensor_names()) with self.assertRaisesRegexp( ValueError, r"This stepper instance does not have access to the value of tensor " r"\"p:0\""): stepper.get_tensor_value("p:0") # The next cont to q should not have used any dumped intermediate tensors # and its result should reflect the updated value. self.assertAllClose(576.0, stepper.cont("q:0")) self.assertSetEqual(set(), stepper.last_updated()) self.assertEqual({}, stepper.last_feed_types())
def testContAfterUpdateWithoutRestoringVariableValue(self): with NodeStepper(self.sess, "optim") as stepper: # First, update Variable a from 1.0 to 0.84. result = stepper.cont( "optim/update_a/ApplyGradientDescent", invalidate_from_updated_variables=True, restore_variable_values=True) self.assertIsNone(result) self.assertSetEqual({"a:0"}, stepper.last_updated()) self.assertEqual(set(["a:0"]), stepper.dirty_variables()) self.assertAllClose(0.84, self.sess.run(self.a)) self.assertAllClose(2.0, self.sess.run(self.b)) self.assertAllClose(4.0, self.sess.run(self.c)) # Tracking of the updated variables should have invalidated all # intermediate tensors downstream to a:0. self.assertNotIn("a/read:0", stepper.intermediate_tensor_names()) self.assertNotIn("d:0", stepper.intermediate_tensor_names()) # Second, update Variable b without the default restore_variable_values. result = stepper.cont( "optim/update_b/ApplyGradientDescent", restore_variable_values=False) self.assertIsNone(result) # For the backprop on Variable b under the updated value of a: # 2.0 - learning_rate * 2 * a' * b * c # = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656 self.assertAllClose(0.84, self.sess.run(self.a)) self.assertAllClose(1.8656, self.sess.run(self.b)) self.assertAllClose(4.0, self.sess.run(self.c))
def testContNotInvalidatingFromVariableUpdatesWorksForNextUpdate(self): with NodeStepper(self.sess, "optim") as stepper: self.assertIsNone(stepper.cont( "optim/update_a/ApplyGradientDescent", invalidate_from_updated_variables=False)) # Even though invalidate_from_updated_variables is set to False, dirty # variables should still have been tracked. self.assertSetEqual({"a:0"}, stepper.last_updated()) self.assertEqual({"a:0"}, stepper.dirty_variables()) self.assertIn("a/read:0", stepper.intermediate_tensor_names()) self.assertIn("b/read:0", stepper.intermediate_tensor_names()) self.assertIn("c/read:0", stepper.intermediate_tensor_names()) self.assertIn("d:0", stepper.intermediate_tensor_names()) self.assertIn("e:0", stepper.intermediate_tensor_names()) self.assertIn("optim/learning_rate:0", stepper.intermediate_tensor_names()) self.assertNotIn("a:0", stepper.intermediate_tensor_names()) self.assertNotIn("b:0", stepper.intermediate_tensor_names()) self.assertNotIn("c:0", stepper.intermediate_tensor_names()) self.assertAllClose(0.84, self.sess.run(self.a)) self.assertAllClose(2.0, self.sess.run(self.b)) self.assertAllClose(4.0, self.sess.run(self.c)) # For the backprop on Variable b, the result should reflect the original # value of Variable a, even though Variable a has actually been updated. # 2.0 - learning_rate * 2 * a * b * c # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84 self.assertIsNone(stepper.cont( "optim/update_b/ApplyGradientDescent", invalidate_from_updated_variables=False, restore_variable_values=False)) self.assertAllClose(0.84, self.sess.run(self.a)) self.assertAllClose(1.84, self.sess.run(self.b)) self.assertAllClose(4.0, self.sess.run(self.c))
def testContWithPlaceholders(self): with NodeStepper( self.sess, self.y, feed_dict={ self.ph0: [[1.0, 2.0], [-3.0, 5.0]], self.ph1: [[-1.0], [0.5]] }) as stepper: self.assertEqual(4, len(stepper.sorted_nodes())) self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"}, set(stepper.closure_elements())) result = stepper.cont(self.x) self.assertAllClose([[0.0], [5.5]], result) self.assertEqual({ "ph0:0": NodeStepper.FEED_TYPE_CLIENT, "ph1:0": NodeStepper.FEED_TYPE_CLIENT, }, stepper.last_feed_types()) self.assertEqual(["x:0"], stepper.handle_names()) self.assertSetEqual({"x"}, stepper.handle_node_names()) result = stepper.cont(self.y) self.assertAllClose([[-1.0], [6.0]], result) self.assertEqual({ "x:0": NodeStepper.FEED_TYPE_HANDLE, "ph1:0": NodeStepper.FEED_TYPE_CLIENT, }, stepper.last_feed_types())
def testTransitiveClosureWithCrossLinksShouldHaveCorrectOrder(self): with NodeStepper(self.sess, "z:0") as stepper: sorted_nodes = stepper.sorted_nodes() self.assertEqual(4, len(sorted_nodes)) self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read")) self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y")) self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z")) self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
def testRepeatedCallsToAssignAddDownStreamDoesNotUpdateVariableAgain(self): with NodeStepper(self.sess, self.v_add_plus_one) as stepper: stepper.cont(self.v_add_plus_one) self.assertSetEqual({self.v.name}, stepper.last_updated()) self.assertAllClose(12.0, stepper.cont(self.v)) stepper.cont(self.v_add_plus_one) self.assertSetEqual(set(), stepper.last_updated()) self.assertEqual({"v_add_plus_one:0": NodeStepper.FEED_TYPE_HANDLE}, stepper.last_feed_types()) self.assertAllClose(12.0, stepper.cont(self.v))
def testDisablingUseDumpedIntermediatesWorks(self): with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper: stepper.cont("c:0") self.assertItemsEqual(["a/read:0", "b/read:0"], stepper.intermediate_tensor_names()) self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) self.assertAllClose(10.0, stepper.cont("f:0", use_dumped_intermediates=False)) self.assertEqual({}, stepper.last_feed_types())
def testRestoreVariableValues(self): """Test restore_variable_values() restores the old values of variables.""" stepper = NodeStepper(self.sess, "optim") stepper.cont( "optim/update_b/ApplyGradientDescent", restore_variable_values=True) self.assertAllClose(1.84, self.sess.run(self.b)) stepper.restore_variable_values() self.assertAllClose(2.0, self.sess.run(self.b))
def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self): for i in range(6): if i == 0: fetches = [self.e, [self.f, self.z]] elif i == 1: fetches = (self.e, (self.f, self.z)) elif i == 2: fetches = {"e": self.e, "fz": {"f": self.f, "z": self.z}} elif i == 3: fetches = ["e:0", ["f:0", "z:0"]] elif i == 4: fetches = ("e:0", ("f:0", "z:0")) elif i == 5: fetches = {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}} stepper = NodeStepper(self.sess, fetches) sorted_nodes = stepper.sorted_nodes() self.assertEqual(13, len(sorted_nodes)) # Check the topological order of the sorted nodes. self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read")) self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y")) self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z")) self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z")) self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read")) self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read")) self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c")) self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c")) self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d")) self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e")) self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e")) self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f")) self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f")) closure_elements = stepper.closure_elements() self.assertIn("x/read:0", closure_elements) self.assertIn("e:0", closure_elements) self.assertIn("f:0", closure_elements) self.assertEqual([0], stepper.output_slots_in_closure("x/read")) self.assertEqual([0], stepper.output_slots_in_closure("e")) self.assertEqual([0], stepper.output_slots_in_closure("f")) result = stepper.finalize() if i == 0 or i == 1 or i == 3 or i == 4: self.assertAllClose(24.0, result[0]) self.assertAllClose(10.0, result[1][0]) self.assertAllClose(-4.0, result[1][1]) elif i == 2 or i == 5: self.assertAllClose(24.0, result["e"]) self.assertAllClose(10.0, result["fz"]["f"]) self.assertAllClose(-4.0, result["fz"]["z"])
def testIsPlaceholdersShouldGiveCorrectAnswers(self): with NodeStepper(self.sess, self.y) as stepper: self.assertTrue(stepper.is_placeholder(self.ph0.name)) self.assertTrue(stepper.is_placeholder(self.ph1.name)) self.assertFalse(stepper.is_placeholder(self.x.name)) self.assertFalse(stepper.is_placeholder(self.y.name)) with self.assertRaisesRegexp(ValueError, "A is not in the transitive closure"): self.assertFalse(stepper.is_placeholder("A"))
def testRemoveOverrideValue(self): with NodeStepper(self.sess, self.e) as stepper: result = stepper.cont(self.c) self.assertAllClose(6.0, result) self.assertEqual({}, stepper.last_feed_types()) # The previous cont() step should have generated a cached tensor handle. self.assertEqual(["c:0"], stepper.handle_names()) self.assertSetEqual({"c"}, stepper.handle_node_names()) # Override c:0. stepper.override_tensor("c:0", 7.0) # The overriding should have invalidated the tensor handle. self.assertEqual([], stepper.handle_names()) self.assertSetEqual(set(), stepper.handle_node_names()) self.assertEqual(["c:0"], stepper.override_names()) result = stepper.cont(self.e) self.assertAllClose(28.0, result) # Should reflect the overriding value. self.assertEqual({ "c:0": NodeStepper.FEED_TYPE_OVERRIDE, "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, }, stepper.last_feed_types()) # The handle to tensor e:0 should have been cached, even though its # transitive closure contains an override. self.assertIn("e:0", stepper.handle_names()) self.assertSetEqual({"e"}, stepper.handle_node_names()) # Remove the override. stepper.remove_override("c:0") # c:0 should not be in the overrides anymore. self.assertEqual([], stepper.override_names()) # Removing the override should have invalidated the tensor handle for c. self.assertNotIn("e:0", stepper.handle_names()) self.assertNotIn("e", stepper.handle_node_names()) # Should reflect the non-overriding value. self.assertAllClose(24.0, stepper.cont(self.e)) # This time, the handle to tensor e:0 should have been cached again, even # thought its transitive closure contains an override. self.assertIn("e:0", stepper.handle_names()) self.assertIn("e", stepper.handle_node_names()) # Calling cont(self.e) again should have used the tensor handle to e:0. self.assertAllClose(24.0, stepper.cont(self.e)) self.assertEqual({ "e:0": NodeStepper.FEED_TYPE_HANDLE, }, stepper.last_feed_types())
def testRemovingOverrideToUpstreamTensorInvalidatesDumpedIntermediates(self): with NodeStepper(self.sess, self.q) as stepper: stepper.override_tensor("v/read:0", 9.0) self.assertItemsEqual(["v/read:0"], stepper.override_names()) self.assertAllClose(324.0, stepper.cont(self.q)) self.assertItemsEqual(["p:0"], stepper.intermediate_tensor_names()) stepper.remove_override("v/read:0") self.assertItemsEqual([], stepper.override_names()) # Removing the pre-existing override to v/read:0 should have invalidated # the dumped intermediate tensor. self.assertItemsEqual([], stepper.intermediate_tensor_names())
def testFinalizeWithPreviousOverrides(self): with NodeStepper(self.sess, self.e) as stepper: stepper.override_tensor("a/read:0", 20.0) self.assertEqual(["a/read:0"], stepper.override_names()) # Should reflect the overriding value. self.assertAllClose(24000.0, stepper.cont("e:0")) self.assertEqual({ "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE }, stepper.last_feed_types()) # Finalize call should have ignored the overriding value. self.assertAllClose(24.0, stepper.finalize())
def testContToUpdateB(self): stepper = NodeStepper(self.sess, "optim") result = stepper.cont("optim/update_b/ApplyGradientDescent") self.assertIsNone(result) self.assertEqual(set(["b:0"]), stepper.dirty_variables()) # For backprop on Variable b: # Because f = a * b * b * c, df / da = 2 * a * b * c. # 2.0 - learning_rate * 2 * a * b * c # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84 self.assertAllClose(1.0, self.sess.run(self.a)) self.assertAllClose(1.84, self.sess.run(self.b)) self.assertAllClose(4.0, self.sess.run(self.c))
def testUsingNamesNotUsingIntermediateTensors(self): stepper = NodeStepper(self.sess, "e:0") # The first cont() call should have used no feeds. result = stepper.cont("c:0") self.assertAllClose(6.0, result) self.assertEqual({}, stepper.last_feed_types()) # The second cont() call should have used the tensor handle from the # previous cont() call. result = stepper.cont("e:0") self.assertAllClose(24.0, result) self.assertEqual({ "c:0": NodeStepper.FEED_TYPE_HANDLE }, stepper.last_feed_types())
def testGetTensorValueWorksOnPlaceholder(self): with NodeStepper( self.sess, self.y, feed_dict={ self.ph0: [[1.0, 2.0], [-3.0, 5.0]], self.ph1: [[-1.0], [0.5]] }) as stepper: self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], stepper.get_tensor_value("ph0")) self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], stepper.get_tensor_value("ph0:0")) with self.assertRaisesRegexp( KeyError, r"The name 'ph0:1' refers to a Tensor which does not exist"): stepper.get_tensor_value("ph0:1")
def testFinalize(self): """Test finalize() to restore variables and run the original fetch.""" stepper = NodeStepper(self.sess, "optim") # Invoke update_b before calling finalize. stepper.cont( "optim/update_b/ApplyGradientDescent", restore_variable_values=True) result = stepper.finalize() self.assertIsNone(result) # The results of the Variable updates should be the same as if no cont() # call has occurred on update_b. self.assertAllClose(0.84, self.sess.run(self.a)) self.assertAllClose(1.84, self.sess.run(self.b)) self.assertAllClose(3.96, self.sess.run(self.c))
def testContToTensorWithIntermediateDumpShouldUseDump(self): with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper: stepper.cont("c:0") self.assertItemsEqual(["a/read:0", "b/read:0"], stepper.intermediate_tensor_names()) self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) self.assertAllClose(2.0, stepper.cont("a/read:0")) self.assertEqual({ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE }, stepper.last_feed_types()) self.assertAllClose(10.0, stepper.cont("f:0")) self.assertEqual({ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE }, stepper.last_feed_types())
def testSelectiveHandleUsageDependingOnTransitiveCleanliness(self): """Test tensor handlers are using only during clean transitive closure. "clean" means no Variables have been updated by preceding cont() calls. """ stepper = NodeStepper(self.sess, "optim") # First, call cont() on the two tensors on the intermediate level: e and f. result = stepper.cont("d:0") self.assertAllClose(2.0, result) self.assertEqual({}, stepper.last_feed_types()) self.assertEqual(set(), stepper.dirty_variables()) # The cont call above should have restored Variable "b". result = stepper.cont("e:0") self.assertAllClose(8.0, result) self.assertEqual({}, stepper.last_feed_types()) self.assertEqual(set(), stepper.dirty_variables()) # Now run update_a, so as to let Variable a be diry. result = stepper.cont( "optim/update_a/ApplyGradientDescent", restore_variable_values=True) self.assertIsNone(result) self.assertEqual({"a:0"}, stepper.dirty_variables()) # Now, run update_b. result = stepper.cont( "optim/update_b/ApplyGradientDescent", restore_variable_values=True) self.assertIsNone(result) # The last cont() run should have use the handle of tensor e, but not the # handle of tensor d, because the transitive closure of e is clean, whereas # that of d is dirty due to the update to a in the previous cont() call. self.assertEqual({ "e:0": NodeStepper.FEED_TYPE_HANDLE }, stepper.last_feed_types()) # The result of the update_b should be identical to as if no other # update_* cont() calls have occurred before. self.assertAllClose(1.0, self.sess.run(self.a)) self.assertAllClose(1.84, self.sess.run(self.b)) self.assertAllClose(4.0, self.sess.run(self.c))