def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
  """NodeStepper should accept nested list, tuple or dict fetches.

  Each nesting structure (list, tuple, dict) is exercised twice: once with
  graph objects (self.e, self.f, self.z) and once with the equivalent
  tensor-name strings.
  """
  fetches_variants = [
      [self.e, [self.f, self.z]],
      (self.e, (self.f, self.z)),
      {"e": self.e, "fz": {"f": self.f, "z": self.z}},
      ["e:0", ["f:0", "z:0"]],
      ("e:0", ("f:0", "z:0")),
      {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}},
  ]

  # (earlier, later) node-name pairs that must respect topological order.
  expected_order = [
      ("x", "x/read"),
      ("x", "y"),
      ("x", "z"),
      ("y", "z"),
      ("a", "a/read"),
      ("b", "b/read"),
      ("a", "c"),
      ("b", "c"),
      ("a", "d"),
      ("d", "e"),
      ("c", "e"),
      ("b", "f"),
      ("f_y", "f"),
  ]

  for fetches in fetches_variants:
    stepper = NodeStepper(self.sess, fetches)

    sorted_nodes = stepper.sorted_nodes()
    self.assertEqual(13, len(sorted_nodes))

    # Check the topological order of the sorted nodes.
    for earlier, later in expected_order:
      self.assertLess(
          sorted_nodes.index(earlier), sorted_nodes.index(later),
          msg="%s should precede %s in %s" % (earlier, later, sorted_nodes))

    closure_elements = stepper.closure_elements()
    for tensor_name in ("x/read:0", "e:0", "f:0"):
      self.assertIn(tensor_name, closure_elements)
    for node_name in ("x/read", "e", "f"):
      self.assertEqual([0], stepper.output_slots_in_closure(node_name))

    result = stepper.finalize()
    # The structure of the result mirrors the structure of the fetches.
    if isinstance(fetches, dict):
      self.assertAllClose(24.0, result["e"])
      self.assertAllClose(10.0, result["fz"]["f"])
      self.assertAllClose(-4.0, result["fz"]["z"])
    else:
      self.assertAllClose(24.0, result[0])
      self.assertAllClose(10.0, result[1][0])
      self.assertAllClose(-4.0, result[1][1])
def testContToNodeWithOutputTensors(self):
  """cont() to an op should cache its output tensors if appropriate."""
  stepper = NodeStepper(self.sess, "optim")

  # In the transitive closure of the stepper, look for an op of which the
  # output tensor also is in the transitive closure.
  # Do not assume a specific op, e.g., "gradients/e_grad/Reshape_1",
  # because it may vary between builds.
  closure_elements = stepper.closure_elements()
  op_with_output_in_closure = next(
      (str(element_name) for element_name in closure_elements
       if element_name + ":0" in closure_elements),
      None)

  # Assert that such an op was found BEFORE using it anywhere else, so a
  # missing op fails here with a clear message instead of inside
  # output_slots_in_closure() or the string concatenation below.
  self.assertIsNotNone(op_with_output_in_closure)
  self.assertEqual([0],
                   stepper.output_slots_in_closure(op_with_output_in_closure))
  output_tensor = op_with_output_in_closure + ":0"

  # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
  # stepper, because it is the control input to another op. However, its
  # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
  # closure, because it is the (non-control) input of certain ops. Calling
  # cont() on the op should lead to the caching of the tensor handle for
  # the output tensor.
  stepper.cont(op_with_output_in_closure)

  self.assertEqual([output_tensor], stepper.handle_names())
  self.assertSetEqual({op_with_output_in_closure},
                      stepper.handle_node_names())

  # Do a cont() call that uses the cached tensor of
  # "gradients/?_grad/Reshape_1:0".
  stepper.cont(output_tensor)
  self.assertEqual({output_tensor: NodeStepper.FEED_TYPE_HANDLE},
                   stepper.last_feed_types())