def testOverrideAndContToSameTensor(self):
    """Overriding a tensor that is cached as a handle replaces that handle."""
    node_stepper = NodeStepper(self.sess, self.e)

    # The first cont() to c computes it without any feeds and caches c:0.
    self.assertAllClose(6.0, node_stepper.cont(self.c))
    self.assertEqual({}, node_stepper.last_feed_types())
    self.assertEqual(["c:0"], node_stepper.handle_names())
    self.assertSetEqual({"c"}, node_stepper.handle_node_names())

    # A repeated cont() to c is fed straight from the cached handle.
    self.assertAllClose(6.0, node_stepper.cont(self.c))
    self.assertEqual({"c:0": NodeStepper.FEED_TYPE_HANDLE},
                     node_stepper.last_feed_types())

    # Overriding c:0 must discard the now-stale handle.
    node_stepper.override_tensor("c:0", 7.0)
    self.assertEqual([], node_stepper.handle_names())
    self.assertSetEqual(set(), node_stepper.handle_node_names())

    # The next cont() to c is fed by the override instead of a handle.
    self.assertAllClose(7.0, node_stepper.cont(self.c))
    self.assertEqual({"c:0": NodeStepper.FEED_TYPE_OVERRIDE},
                     node_stepper.last_feed_types())
def testOverrideValueTwice(self):
    """Re-overriding a tensor invalidates handles computed from the old value."""
    stepper = NodeStepper(self.sess, self.e)

    # Override once.
    stepper.override_tensor("c:0", 7.0)
    self.assertAllClose(28.0, stepper.cont(self.e))
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_OVERRIDE
    }, stepper.last_feed_types())

    self.assertEqual(["e:0"], stepper.handle_names())
    self.assertSetEqual({"e"}, stepper.handle_node_names())
    self.assertEqual(["c:0"], stepper.override_names())

    # Calling cont(self.e) again. This time the cached tensor handle of e
    # should be used. Compare with assertAllClose, like every other computed
    # tensor value in these tests, rather than exact float equality.
    self.assertAllClose(28.0, stepper.cont(self.e))
    self.assertEqual({
        "e:0": NodeStepper.FEED_TYPE_HANDLE
    }, stepper.last_feed_types())

    # Override c again. This should have invalidated the cache for e.
    stepper.override_tensor("c:0", 8.0)

    self.assertEqual([], stepper.handle_names())
    self.assertSetEqual(set(), stepper.handle_node_names())
    self.assertEqual(["c:0"], stepper.override_names())

    # The new overriding value of c:0 should be reflected in e.
    self.assertAllClose(32.0, stepper.cont(self.e))
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_OVERRIDE
    }, stepper.last_feed_types())
def testRemoveOverrideValue(self):
    """Removing an override invalidates dependent handles and restores values."""
    stepper = NodeStepper(self.sess, self.e)

    result = stepper.cont(self.c)
    self.assertAllClose(6.0, result)
    self.assertEqual({}, stepper.last_feed_types())

    # The previous cont() step should have generated a cached tensor handle.
    self.assertEqual(["c:0"], stepper.handle_names())
    self.assertSetEqual({"c"}, stepper.handle_node_names())

    # Override c:0.
    stepper.override_tensor("c:0", 7.0)

    # The overriding should have invalidated the tensor handle.
    self.assertEqual([], stepper.handle_names())
    self.assertSetEqual(set(), stepper.handle_node_names())
    self.assertEqual(["c:0"], stepper.override_names())

    result = stepper.cont(self.e)
    self.assertAllClose(28.0, result)  # Should reflect the overriding value.
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_OVERRIDE
    }, stepper.last_feed_types())

    # The handle to tensor e:0 should have been cached, even though its
    # transitive closure contains an override.
    self.assertIn("e:0", stepper.handle_names())
    self.assertSetEqual({"e"}, stepper.handle_node_names())

    # Remove the override.
    stepper.remove_override("c:0")
    # c:0 should not be in the overrides anymore.
    self.assertEqual([], stepper.override_names())

    # Removing the override should have invalidated the tensor handle for e,
    # whose cached value was computed from the overriding value of c:0.
    self.assertNotIn("e:0", stepper.handle_names())
    self.assertNotIn("e", stepper.handle_node_names())

    # Should reflect the non-overriding value.
    self.assertAllClose(24.0, stepper.cont(self.e))

    # This time, the handle to tensor e:0 should have been cached again; its
    # transitive closure no longer contains an override.
    self.assertIn("e:0", stepper.handle_names())
    self.assertIn("e", stepper.handle_node_names())

    # Calling cont(self.e) again should have used the tensor handle to e:0.
    self.assertAllClose(24.0, stepper.cont(self.e))
    self.assertEqual({
        "e:0": NodeStepper.FEED_TYPE_HANDLE
    }, stepper.last_feed_types())
def testFinalizeWithPreviousOverrides(self):
    """cont() honors an override while finalize() ignores it."""
    node_stepper = NodeStepper(self.sess, self.e)

    overridden_name = "a/read:0"
    node_stepper.override_tensor(overridden_name, 20.0)
    self.assertEqual([overridden_name], node_stepper.override_names())

    # cont() is fed by the overriding value ...
    self.assertAllClose(24000.0, node_stepper.cont("e:0"))
    self.assertEqual({overridden_name: NodeStepper.FEED_TYPE_OVERRIDE},
                     node_stepper.last_feed_types())

    # ... whereas finalize() disregards it and yields the unmodified result.
    self.assertAllClose(24.0, node_stepper.finalize())
def testOverrideValue(self):
    """An override replaces a cached handle and feeds downstream cont() calls."""
    stepper = NodeStepper(self.sess, self.e)

    result = stepper.cont(self.c)
    self.assertAllClose(6.0, result)
    self.assertEqual({}, stepper.last_feed_types())

    # No overrides have been registered yet.
    self.assertEqual([], stepper.override_names())

    # Calling cont() on c again should lead to use of the handle.
    result = stepper.cont(self.c)
    self.assertAllClose(6.0, result)
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_HANDLE
    }, stepper.last_feed_types())

    # Override c:0.
    stepper.override_tensor("c:0", 7.0)

    # After the overriding, calling get_tensor_value() on c:0 should yield the
    # overriding value. assertAllClose is used for consistency with the other
    # tensor-value checks instead of exact float equality.
    self.assertAllClose(7.0, stepper.get_tensor_value("c:0"))

    # Now c:0 should have only an override value, but no cached handle, because
    # the handle should have been invalidated.
    self.assertEqual([], stepper.handle_names())
    self.assertSetEqual(set(), stepper.handle_node_names())
    self.assertEqual(["c:0"], stepper.override_names())

    # Run a downstream tensor after the value override.
    result = stepper.cont(self.e)
    self.assertAllClose(28.0, result)  # Should reflect the overriding value.

    # Should use override, instead of the handle.
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_OVERRIDE
    }, stepper.last_feed_types())
def testAttemptToOverrideInvalidTensor(self):
    """Overriding the tensor name "f:0" is rejected with a ValueError."""
    node_stepper = NodeStepper(self.sess, self.e)
    # NOTE(review): presumably "f:0" is not part of the stepper's graph.
    expected_pattern = 'Cannot override tensor "f:0"'
    with self.assertRaisesRegexp(ValueError, expected_pattern):
        node_stepper.override_tensor("f:0", 42.0)
def testOverrideThenContToUpdate(self):
    """Test cont() to update nodes after overriding tensor values."""
    stepper = NodeStepper(self.sess, "optim")

    result = stepper.cont("d:0")
    self.assertAllClose(2.0, result)
    self.assertEqual({}, stepper.last_feed_types())
    self.assertEqual(set(), stepper.dirty_variables())
    self.assertEqual(["d:0"], stepper.handle_names())
    self.assertSetEqual({"d"}, stepper.handle_node_names())

    # Override the value from 1.0 to 10.0.
    stepper.override_tensor("a/read:0", 10.0)

    self.assertEqual(["a/read:0"], stepper.override_names())

    result = stepper.cont(
        "optim/update_c/ApplyGradientDescent",
        restore_variable_values=True)
    self.assertIsNone(result)

    # The last cont() call should not have used the tensor handle to d:0,
    # because the transitive closure of d:0 contains an override tensor.
    self.assertEqual({
        "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE
    }, stepper.last_feed_types())

    # The tensor handle to d:0 should have been removed due to the dirty
    # transitive closure.
    self.assertEqual([], stepper.handle_names())
    self.assertSetEqual(set(), stepper.handle_node_names())

    # For this backprop on c, the overriding value of a/read:0 should have been
    # used:
    #   4.0 - learning_rate * a * b * b
    #     = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6.
    self.assertAllClose(3.6, self.sess.run(self.c))

    # Now remove the overriding value of a/read:0.
    stepper.remove_override("a/read:0")
    self.assertEqual([], stepper.override_names())

    # Obtain the tensor handle to d:0 again.
    result = stepper.cont("d:0")
    self.assertAllClose(2.0, result)
    self.assertEqual(["d:0"], stepper.handle_names())
    self.assertSetEqual({"d"}, stepper.handle_node_names())

    # Then call update_c again, without restoring c.
    result = stepper.cont(
        "optim/update_c/ApplyGradientDescent", restore_variable_values=False)
    self.assertIsNone(result)

    # This time, the d:0 tensor handle should have been used, because its
    # transitive closure is clean.
    self.assertEqual({
        "d:0": NodeStepper.FEED_TYPE_HANDLE
    }, stepper.last_feed_types())

    # For this backprop on c, the override has been removed, so the original
    # value of a/read:0 (1.0) should have been used, starting from the
    # unrestored c value of 3.6:
    #   3.6 - learning_rate * a * b * b
    #     = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56.
    self.assertAllClose(3.56, self.sess.run(self.c))
def testInvalidOverrideArgumentType(self):
    """override_tensor() raises TypeError when the name is not a str."""
    node_stepper = NodeStepper(self.sess, self.e)
    non_string_name = self.a  # A graph object rather than a tensor-name string.
    with self.assertRaisesRegexp(TypeError, "Expected type str; got type"):
        node_stepper.override_tensor(non_string_name, 42.0)