def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
  """A cont with -r restores the variable dirtied by the previous cont.

  First continues to opt/update_a/ApplyGradientDescent and checks that only
  "a" carries STATE_DIRTY_VARIABLE; then continues to
  opt/update_b/ApplyGradientDescent with -r and checks that "b" is dirty and
  "a" no longer is.

  NOTE(review): a later method in this file has exactly the same name; if both
  are defined in the same class, the later definition replaces this one and
  this body never runs. Remove or rename one of them.
  """
  # Use the stepper as a context manager (like the other tests in this file)
  # so its exit hook runs even when an assertion fails.
  with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    cli.cont(["opt/update_a/ApplyGradientDescent"])

    # After cont() call on .../update_a/..., Variable a should have been
    # marked as dirty, whereas b should not have.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("a")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("b")])

    cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])

    # After cont() call on .../update_b/... with the -r flag, Variable b
    # should have been marked as dirty, whereas Variable a should not be
    # because it should have been restored.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("b")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("a")])
def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
  """cont to a node whose outputs are outside the closure caches nothing.

  Picks the first sorted node for which neither "<node>:0" nor "<node>:1" is
  a closure element, continues to it, and checks that:
    * `_parse_updated` finds nothing in the cont output,
    * the pointer lands on that node,
    * the node is not labeled STATE_CONT.
  """
  with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
    ordered = node_stepper.sorted_nodes()
    closure = node_stepper.closure_elements()

    # Only output slots 0 and 1 are inspected.
    no_output_node = next(
        (name for name in ordered
         if not (name + ":0" in closure or name + ":1" in closure)),
        None)
    self.assertIsNotNone(no_output_node)

    cli = stepper_cli.NodeStepperCLI(node_stepper)
    cont_lines = cli.cont([no_output_node]).lines
    self.assertIsNone(_parse_updated(cont_lines))

    names, labels, pointer = _parse_sorted_nodes_list(cont_lines)
    self.assertEqual(no_output_node, names[pointer])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, labels[pointer])
def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
  """A cont with -r restores the variable dirtied by the previous cont.

  Step 1: cont to update_a with --invalidate_from_updated_variables. Only
  self.a is reported as updated, and only "a" is labeled dirty.
  Step 2: cont to update_b with -r and -i. Only self.b is reported as
  updated, "b" is labeled dirty, and "a" is not, because it was restored.
  """
  dirty_label = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

  with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    def _label_listing():
      """Returns (node_names, stat_labels) from a fresh sorted-node listing."""
      names, labels, _ = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      return names, labels

    # Step 1: update_a dirties a but not b.
    cont_lines = cli.cont(["opt/update_a/ApplyGradientDescent",
                           "--invalidate_from_updated_variables"]).lines
    self.assertItemsEqual([self.a.name], _parse_updated(cont_lines))
    names, labels = _label_listing()
    self.assertIn(dirty_label, labels[names.index("a")])
    self.assertNotIn(dirty_label, labels[names.index("b")])

    # Step 2: with -r the earlier change to a is restored, so only b is
    # left dirty.
    cont_lines = cli.cont(
        ["opt/update_b/ApplyGradientDescent", "-r", "-i"]).lines
    self.assertItemsEqual([self.b.name], _parse_updated(cont_lines))
    names, labels = _label_listing()
    self.assertIn(dirty_label, labels[names.index("b")])
    self.assertNotIn(dirty_label, labels[names.index("a")])
def testContToValidNodeShouldUpdateStatus(self):
  """cont to c and then d moves the pointer and labels each as STATE_CONT.

  The second cont should report that c:0 was fed with feed type
  FEED_TYPE_HANDLE.

  NOTE(review): a later method in this file has exactly the same name; if both
  are defined in the same class, the later definition replaces this one and
  this body never runs. Remove or rename one of them.
  """
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    # Before any cont, c has a blank label and the pointer is at index 0.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)
    index_c = node_names.index("c")
    self.assertEqual(" ", stat_labels[index_c])
    self.assertEqual(0, node_pointer)

    # Pass a one-element argument list rather than the bare string "c": a
    # bare string only happens to work while the node name is one character.
    output = cli.cont(["c"])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)
    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("c", node_names)
    index_c = node_names.index("c")
    self.assertEqual(index_c, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])

    output = cli.cont(["d"])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)
    used_feed_types = _parsed_used_feeds(output.lines)
    self.assertEqual({"c:0": stepper.NodeStepper.FEED_TYPE_HANDLE},
                     used_feed_types)
    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("d", node_names)
    index_d = node_names.index("d")
    self.assertEqual(index_d, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
def testInjectTensorValueByTensorNameShouldBeReflected(self):
  """Injecting a value by tensor name ("d:0") is recorded as an override.

  After cont to d, the listing labels d with STATE_CONT (not
  STATE_OVERRIDDEN) and d:0 evaluates to -20.0. After inject_value on d:0,
  the stepper lists d:0 as its only override, and the listing labels d with
  STATE_OVERRIDDEN instead of STATE_CONT.
  """
  cont_label = stepper_cli.NodeStepperCLI.STATE_CONT
  override_label = stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN

  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    names, _, pointer = _parse_sorted_nodes_list(cli.cont(["d"]).lines)
    self.assertEqual("d", names[pointer])

    names, labels, _ = _parse_sorted_nodes_list(
        cli.list_sorted_nodes([]).lines)
    d_label = labels[names.index("d")]
    self.assertIn(cont_label, d_label)
    self.assertNotIn(override_label, d_label)
    self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))

    cli.inject_value(["d:0", "20.0"])
    # The override must be registered with the underlying stepper ...
    self.assertEqual(["d:0"], node_stepper.override_names())

    # ... and be reflected in the sorted-node listing.
    names, labels, _ = _parse_sorted_nodes_list(
        cli.list_sorted_nodes([]).lines)
    d_label = labels[names.index("d")]
    self.assertNotIn(cont_label, d_label)
    self.assertIn(override_label, d_label)
def testContToNodeOutsideTransitiveClosureShouldError(self):
  """cont to a node outside the stepper's transitive closure reports an error."""
  # Context manager so the stepper's exit hook runs even if the assert fails.
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    output = cli.cont(["f"])
    self.assertEqual([
        "ERROR: f is not in the transitive closure of this stepper "
        "instance."
    ], output.lines)
def testContToNonexistentNodeShouldError(self):
  """cont to a node name absent from the graph reports an error."""
  # Context manager so the stepper's exit hook runs even if the assert fails.
  with stepper.NodeStepper(self.sess, self.f) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    output = cli.cont(["foobar"])
    self.assertEqual([
        "ERROR: foobar is not in the transitive closure of this stepper "
        "instance."
    ], output.lines)
def testPrintTensorWithNonexistentTensorShouldError(self):
  """print_tensor on a name outside the closure reports an error."""
  # Context manager so the stepper's exit hook runs even if the assert fails.
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    output = cli.print_tensor(["foobar"])
    self.assertEqual([
        "ERROR: foobar is not in the transitive closure of this stepper "
        "instance."
    ], output.lines)
def testPrintTensorWithNoHandleShouldError(self):
  """print_tensor on a tensor the stepper has not computed reports so."""
  # Context manager so the stepper's exit hook runs even if the assert fails.
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    # Pass an argument list rather than the bare string "e"; a bare string
    # only happens to work while the name is a single character.
    output = cli.print_tensor(["e"])
    self.assertEqual([
        "This stepper instance does not have access to the value of tensor "
        "\"e:0\""
    ], output.lines)
def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self):
  """After cont to d, print_tensor given the node name prints d:0."""
  # Context manager so the stepper's exit hook runs even if an assert fails.
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    # Argument lists, not bare strings: a bare string only happens to work
    # while the node name is a single character.
    cli.cont(["d"])
    output = cli.print_tensor(["d"])
    self.assertEqual("Tensor \"d:0\":", output.lines[0])
    self.assertEqual("-20.0", output.lines[-1])
def testPrintTensorShouldWorkWithTensorName(self):
  """After cont to d, print_tensor given the tensor name "d:0" prints it."""
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    # Argument list, not the bare string "d": a bare string only happens to
    # work while the node name is a single character.
    cli.cont(["d"])
    output = cli.print_tensor(["d:0"])
    self.assertEqual("Tensor \"d:0\":", output.lines[0])
    self.assertEqual("-20.0", output.lines[-1])
def testInjectToNonexistentTensorShouldError(self):
  """inject_value on a tensor outside the closure reports an error."""
  expected_lines = [
      "ERROR: foobar:0 is not in the transitive closure of this stepper "
      "instance."
  ]
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    inject_lines = cli.inject_value(["foobar:0", "20.0"]).lines
    self.assertEqual(expected_lines, inject_lines)
def testContToUpdateNodeLeadsToDirtyVariableLabel(self):
  """cont to update_b labels "b", and not "a", with STATE_DIRTY_VARIABLE."""
  # Context manager so the stepper's exit hook runs even if an assert fails.
  with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    cli.cont(["opt/update_b/ApplyGradientDescent"])

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("b")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("a")])
def testListingSortedNodesPresentsTransitveClosure(self):
  """list_sorted_nodes on a fresh stepper shows the closure of e, unlabeled.

  The listed nodes must satisfy
  `_assert_nodes_topologically_sorted_with_target_e`, every status label must
  be " ", and the pointer must be at index 0.

  (The method name misspells "Transitive"; it is left unchanged so that
  selecting this test by name keeps working.)
  """
  # Context manager so the stepper's exit hook runs even if an assert fails.
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self._assert_nodes_topologically_sorted_with_target_e(node_names)
    self.assertEqual(len(node_names), len(stat_labels))
    for stat_label in stat_labels:
      self.assertEqual(" ", stat_label)
    self.assertEqual(0, node_pointer)
def testPrintTensorShouldWorkSlicingString(self):
  """print_tensor accepts a slicing suffix on both tensor and node names.

  "ph:0[:, 1]" and "ph[:, 1]" must both be reported as tensor
  "ph:0[:, 1]" and end with the repr of the fed value's column 1.
  """
  ph_value = np.array([[1.0, 0.0], [0.0, 2.0]])
  with stepper.NodeStepper(
      self.sess, self.f, feed_dict={self.ph: ph_value}) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    for query in ("ph:0[:, 1]", "ph[:, 1]"):
      lines = cli.print_tensor([query]).lines
      self.assertEqual("Tensor \"ph:0[:, 1]\":", lines[0])
      self.assertEqual(repr(ph_value[:, 1]), lines[-1])
def testContToUpdateNodeWithoutTrackingLeadsToNoDirtyVariableLabel(self):
  """cont to update_b reports self.b as updated and labels only b dirty.

  NOTE(review): despite "NoDirtyVariableLabel" in the method name, the
  assertions require "b" to carry STATE_DIRTY_VARIABLE (and "a" not to).
  Consider renaming the method or revisiting the assertions.
  """
  dirty_label = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE
  with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    cont_lines = cli.cont(["opt/update_b/ApplyGradientDescent"]).lines
    self.assertItemsEqual([self.b.name], _parse_updated(cont_lines))

    listing_lines = cli.list_sorted_nodes([]).lines
    names, labels, _ = _parse_sorted_nodes_list(listing_lines)
    self.assertIn(dirty_label, labels[names.index("b")])
    self.assertNotIn(dirty_label, labels[names.index("a")])
def testSteppingMultipleStepsUpdatesStatus(self):
  """`step -t 3` moves the pointer and labels the stepped-over nodes.

  After the call, the node at the pointer is the third node of the original
  listing. Every node before the pointer carries STATE_CONT and no node after
  it does. The label at the pointer itself is not checked.
  """
  # Context manager so the stepper's exit hook runs even if an assert fails.
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    output = cli.list_sorted_nodes([])
    orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines)

    output = cli.step(["-t", "3"])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)
    self.assertEqual(orig_node_names[2], node_names[node_pointer])

    for i, stat_label in enumerate(stat_labels):
      if i < node_pointer:
        self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_label)
      elif i > node_pointer:
        self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_label)
def testContToValidNodeShouldUpdateStatus(self):
  """cont to c and then d moves the pointer and labels each as STATE_CONT.

  Neither cont reports updated variables. The second cont reports c:0 fed
  with FEED_TYPE_HANDLE and a/read:0 fed with FEED_TYPE_DUMPED_INTERMEDIATE.
  """
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    # Before any cont, c has a blank label and the pointer is at index 0.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)
    index_c = node_names.index("c")
    self.assertEqual(" ", stat_labels[index_c])
    self.assertEqual(0, node_pointer)

    # Pass one-element argument lists rather than bare strings ("c", "d"):
    # a bare string only happens to work while the node name is one char.
    output = cli.cont(["c"])
    self.assertIsNone(_parse_updated(output.lines))
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)
    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("c", node_names)
    index_c = node_names.index("c")
    self.assertEqual(index_c, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])

    output = cli.cont(["d"])
    self.assertIsNone(_parse_updated(output.lines))
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)
    used_feed_types = _parsed_used_feeds(output.lines)
    self.assertEqual(
        {
            "c:0": stepper.NodeStepper.FEED_TYPE_HANDLE,
            "a/read:0": stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
        }, used_feed_types)
    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("d", node_names)
    index_d = node_names.index("d")
    self.assertEqual(index_d, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
def testListingSortedNodesLabelsPlaceholders(self):
  """list_sorted_nodes labels exactly the "ph" node as a placeholder.

  The listing for target f must satisfy
  `_assert_nodes_topologically_sorted_with_target_f`. STATE_IS_PLACEHOLDER
  must appear on "ph" and on no other node, and the pointer must be at 0.
  """
  # Context manager so the stepper's exit hook runs even if an assert fails.
  with stepper.NodeStepper(self.sess, self.f) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self._assert_nodes_topologically_sorted_with_target_f(node_names)

    index_ph = node_names.index("ph")
    self.assertEqual(len(node_names), len(stat_labels))
    for i, stat_label in enumerate(stat_labels):
      if i == index_ph:
        self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
                      stat_label)
      else:
        self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
                         stat_label)

    self.assertEqual(0, node_pointer)
def testSteppingOneStepAtATimeShouldUpdateStatus(self):
  """Single steps walk the sorted node list in its original order.

  Each step lands on the next node of the original listing and labels it
  STATE_CONT. The listing order never changes. The pointer advances by one
  per step and becomes -1 after the last node. A further step reports an
  error. Skipped on GPU (see the skip message).
  """
  if test_util.is_gpu_available():
    self.skipTest("b/123446705 this causes a segfault on GPU")

  cont_label = stepper_cli.NodeStepperCLI.STATE_CONT
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    orig_node_names, _, start_pointer = _parse_sorted_nodes_list(
        cli.list_sorted_nodes([]).lines)
    self.assertEqual(0, start_pointer)

    last_index = len(orig_node_names) - 1
    for i, expected_name in enumerate(orig_node_names):
      names, labels, pointer = _parse_sorted_nodes_list(cli.step([]).lines)
      self.assertEqual(expected_name, names[pointer])
      self.assertIn(cont_label, labels[pointer])

      # The listing order must stay fixed while stepping.
      names, _, pointer = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      self.assertEqual(orig_node_names, names)
      # Past the last node the pointer is expected to be -1.
      expected_pointer = i + 1 if i < last_index else -1
      self.assertEqual(expected_pointer, pointer)

    # Stepping beyond the end must produce an error message.
    self.assertEqual([
        "ERROR: Cannot step any further because the end of the sorted "
        "transitive closure has been reached."
    ], cli.step([]).lines)
def invoke_node_stepper(self,
                        node_stepper,
                        restore_variable_values_on_exit=True):
  """Overrides method in base class to implement interactive node stepper.

  Builds a curses UI around a `stepper_cli.NodeStepperCLI`. It registers the
  stepper commands with their aliases and node-name tab completion, then runs
  the UI with "lt" as the initial command.

  Args:
    node_stepper: (stepper.NodeStepper) The underlying NodeStepper API object.
    restore_variable_values_on_exit: (bool) Whether any variables whose values
      have been altered during this node-stepper invocation should be restored
      to their old values when this invocation ends.

  Returns:
    The same return values as the `Session.run()` call on the same fetches as
    the NodeStepper.
  """
  # Named `cli` (not `stepper`) so the local does not shadow the `stepper`
  # module name.
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  # On exiting the UI, optionally call node_stepper.restore_variable_values
  # so variables altered while stepping get their old values back.
  # TODO(cais): Perhaps some users will want the effect of the interactive
  # stepping and value injection to persist beyond the UI session.
  stepper_ui = curses_ui.CursesUI(
      on_ui_exit=(node_stepper.restore_variable_values
                  if restore_variable_values_on_exit else None))

  stepper_ui.register_command_handler(
      "list_sorted_nodes",
      cli.list_sorted_nodes,
      cli.arg_parsers["list_sorted_nodes"].format_help(),
      prefix_aliases=["lt", "lsn"])
  stepper_ui.register_command_handler(
      "cont",
      cli.cont,
      cli.arg_parsers["cont"].format_help(),
      prefix_aliases=["ct", "c"])
  stepper_ui.register_command_handler(
      "step",
      cli.step,
      cli.arg_parsers["step"].format_help(),
      prefix_aliases=["st", "s"])
  stepper_ui.register_command_handler(
      "print_tensor",
      cli.print_tensor,
      cli.arg_parsers["print_tensor"].format_help(),
      prefix_aliases=["pt"])
  stepper_ui.register_command_handler(
      "inject_value",
      cli.inject_value,
      cli.arg_parsers["inject_value"].format_help(),
      prefix_aliases=["inject", "override_value", "override"])

  # Node-name tab completion for every command that takes a node/tensor name.
  # Each command's full name and its aliases are all listed; "print_tensor"
  # was previously missing even though its alias "pt" was present.
  # TODO(cais): Tie up register_tab_comp_context to a single alias to shorten
  # calls like this.
  stepper_ui.register_tab_comp_context([
      "cont", "ct", "c", "print_tensor", "pt", "inject_value", "inject",
      "override_value", "override"
  ], [str(elem) for elem in node_stepper.sorted_nodes()])

  return stepper_ui.run_ui(
      init_command="lt",
      title="Node Stepper: " + self._run_description,
      title_color="blue_on_white")
def testInjectTensorValueByNodeNameShouldBeReflected(self):
  """inject_value given the bare node name "d" records an override of d:0."""
  expected_overrides = ["d:0"]
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    stepper_cli_obj = stepper_cli.NodeStepperCLI(node_stepper)
    stepper_cli_obj.inject_value(["d", "20.0"])
    self.assertEqual(expected_overrides, node_stepper.override_names())