def testUsingNodesNotUsingIntermediateTensors(self):
  """cont() should accept graph elements directly, not only name strings.

  Also checks that after cont(self.c) the value of c:0 is reachable through a
  tensor handle, and that this handle is fed when continuing to self.e.
  """
  node_stepper = NodeStepper(self.sess, self.e)

  # No cont() call has happened yet, so no tensor handles exist.
  self.assertEqual([], node_stepper.handle_names())
  self.assertSetEqual(set(), node_stepper.handle_node_names())

  # c:0 has not been computed yet, so its value cannot be fetched.
  with self.assertRaisesRegexp(
      ValueError,
      "This stepper instance does not have access to the value of tensor "
      "\"c:0\""):
    node_stepper.get_tensor_value("c:0")

  # Pass the tensor object itself rather than its name.
  c_value = node_stepper.cont(self.c)
  self.assertAllClose(6.0, c_value)
  # Nothing needed to be fed for this first run.
  self.assertEqual({}, node_stepper.last_feed_types())
  self.assertEqual(["c:0"], node_stepper.handle_names())
  self.assertEqual({"c"}, node_stepper.handle_node_names())

  # c:0 is now retrievable through the handle cached by the run above.
  self.assertAllClose(6.0, node_stepper.get_tensor_value("c:0"))

  # Continuing to the downstream tensor e reuses the c:0 handle as a feed.
  e_value = node_stepper.cont(self.e)
  self.assertAllClose(24.0, e_value)
  expected_feed_types = {"c:0": NodeStepper.FEED_TYPE_HANDLE}
  self.assertEqual(expected_feed_types, node_stepper.last_feed_types())
def testUsingNodesNotUsingIntermediateTensors(self):
  """cont() should accept graph elements directly, not only name strings.

  Verifies the handle bookkeeping as well: before any cont() call there are no
  handles (by tensor name or by node name); after cont(self.c) the value of
  c:0 is available through a handle, and that handle is fed when continuing
  to self.e.
  """
  stepper = NodeStepper(self.sess, self.e)

  # There should be no handles before any cont() calls, whether listed by
  # tensor name or by node name.
  self.assertEqual([], stepper.handle_names())
  self.assertSetEqual(set(), stepper.handle_node_names())

  # Before the cont() call, the stepper should not have access to the value
  # of c:0.
  # NOTE(review): assertRaisesRegexp is a deprecated unittest alias (removed in
  # Python 3.12); migrate to assertRaisesRegex once Python 2 compatibility is
  # no longer required, unless the base TestCase re-provides it -- verify.
  with self.assertRaisesRegexp(
      ValueError,
      "This stepper instance does not have access to the value of tensor "
      "\"c:0\""):
    stepper.get_tensor_value("c:0")

  # Using the node/tensor itself, instead of the name str, should work on
  # cont().
  result = stepper.cont(self.c)
  self.assertAllClose(6.0, result)
  self.assertEqual({}, stepper.last_feed_types())
  self.assertEqual(["c:0"], stepper.handle_names())
  # The node-name view of the handles must agree with the tensor-name view.
  self.assertEqual({"c"}, stepper.handle_node_names())

  # After the cont() call, the stepper should have access to the value of c:0
  # via a tensor handle.
  self.assertAllClose(6.0, stepper.get_tensor_value("c:0"))

  # Continuing to e should feed c:0 from its cached handle.
  result = stepper.cont(self.e)
  self.assertAllClose(24.0, result)
  self.assertEqual({
      "c:0": NodeStepper.FEED_TYPE_HANDLE
  }, stepper.last_feed_types())
def testGetTensorValueWorksOnPlaceholder(self):
  """get_tensor_value() returns the fed value of a placeholder.

  The placeholder can be addressed by node name ("ph0") or by tensor name
  ("ph0:0"). A nonexistent output ("ph0:1") raises KeyError.
  """
  feeds = {
      self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
      self.ph1: [[-1.0], [0.5]],
  }
  node_stepper = NodeStepper(self.sess, self.y, feed_dict=feeds)

  # The bare node name and the explicit ":0" tensor name resolve identically.
  expected_ph0 = [[1.0, 2.0], [-3.0, 5.0]]
  for lookup_name in ("ph0", "ph0:0"):
    self.assertAllClose(expected_ph0, node_stepper.get_tensor_value(lookup_name))

  # ph0 has no output at index 1.
  with self.assertRaisesRegexp(
      KeyError, r"The name 'ph0:1' refers to a Tensor which does not exist"):
    node_stepper.get_tensor_value("ph0:1")
def testGetTensorValueWorksOnPlaceholder(self):
  """Fed placeholder values are readable via get_tensor_value().

  Both "ph0" and "ph0:0" refer to the placeholder's value. Asking for output
  index 1 ("ph0:1") raises KeyError.
  """
  ph0_feed = [[1.0, 2.0], [-3.0, 5.0]]
  ph1_feed = [[-1.0], [0.5]]
  node_stepper = NodeStepper(
      self.sess, self.y, feed_dict={self.ph0: ph0_feed, self.ph1: ph1_feed})

  # Lookup by node name.
  by_node_name = node_stepper.get_tensor_value("ph0")
  self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], by_node_name)

  # Lookup by fully qualified tensor name.
  by_tensor_name = node_stepper.get_tensor_value("ph0:0")
  self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], by_tensor_name)

  # An output index the placeholder does not have is rejected.
  with self.assertRaisesRegexp(
      KeyError, r"The name 'ph0:1' refers to a Tensor which does not exist"):
    node_stepper.get_tensor_value("ph0:1")
def testOverrideValue(self):
  """override_tensor() replaces a cached handle with a user-supplied value.

  After c:0 is overridden, get_tensor_value() returns the override, the c:0
  handle is dropped, and a later cont() to e feeds the override instead of
  the handle.
  """
  node_stepper = NodeStepper(self.sess, self.e)

  # First run of c: nothing is fed yet.
  self.assertAllClose(6.0, node_stepper.cont(self.c))
  self.assertEqual({}, node_stepper.last_feed_types())

  # No override has been registered so far.
  self.assertEqual([], node_stepper.override_names())

  # A repeated run of c is served from the handle cached by the first run.
  self.assertAllClose(6.0, node_stepper.cont(self.c))
  self.assertEqual({"c:0": NodeStepper.FEED_TYPE_HANDLE},
                   node_stepper.last_feed_types())

  node_stepper.override_tensor("c:0", 7.0)

  # Reading c:0 now yields the overriding value.
  self.assertEqual(7.0, node_stepper.get_tensor_value("c:0"))

  # The override invalidated the cached handle; only the override remains.
  self.assertEqual([], node_stepper.handle_names())
  self.assertSetEqual(set(), node_stepper.handle_node_names())
  self.assertEqual(["c:0"], node_stepper.override_names())

  # e must be computed from the overriding value, with c:0 fed as an
  # override rather than a handle.
  self.assertAllClose(28.0, node_stepper.cont(self.e))
  self.assertEqual({"c:0": NodeStepper.FEED_TYPE_OVERRIDE},
                   node_stepper.last_feed_types())
def testOverrideValue(self):
  """An overridden tensor value takes precedence over its cached handle.

  Sequence checked: run c twice (the second run feeds c:0 from its handle),
  override c:0, confirm the handle is gone, then run e and confirm c:0 is fed
  from the override.
  """
  node_stepper = NodeStepper(self.sess, self.e)

  first_c = node_stepper.cont(self.c)
  self.assertAllClose(6.0, first_c)
  self.assertEqual({}, node_stepper.last_feed_types())

  # Nothing has been overridden at this point.
  self.assertEqual([], node_stepper.override_names())

  # Re-running c should feed c:0 from the handle created above.
  second_c = node_stepper.cont(self.c)
  self.assertAllClose(6.0, second_c)
  handle_feed = {"c:0": NodeStepper.FEED_TYPE_HANDLE}
  self.assertEqual(handle_feed, node_stepper.last_feed_types())

  # Replace the value of c:0 and read it back.
  node_stepper.override_tensor("c:0", 7.0)
  self.assertEqual(7.0, node_stepper.get_tensor_value("c:0"))

  # Overriding must have discarded the handle for c:0.
  self.assertEqual([], node_stepper.handle_names())
  self.assertSetEqual(set(), node_stepper.handle_node_names())
  self.assertEqual(["c:0"], node_stepper.override_names())

  # The downstream result reflects the override, which is what gets fed.
  e_value = node_stepper.cont(self.e)
  self.assertAllClose(28.0, e_value)
  override_feed = {"c:0": NodeStepper.FEED_TYPE_OVERRIDE}
  self.assertEqual(override_feed, node_stepper.last_feed_types())