def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
  """cont() with -r restores variables dirtied by an earlier cont().

  First continues to the update op of `a`, which should dirty only `a`.
  Then continues to the update op of `b` with "-r" and "-i". Afterwards `b`
  should be dirty while `a` should be clean again, because it was restored.
  """
  dirty_state = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

  with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    # Step 1: update `a`, with --invalidate_from_updated_variables.
    cont_result = cli.cont([
        "opt/update_a/ApplyGradientDescent",
        "--invalidate_from_updated_variables",
    ])
    self.assertItemsEqual([self.a.name], _parse_updated(cont_result.lines))

    # Only `a` should carry the dirty-variable label at this point.
    listing = cli.list_sorted_nodes([])
    names, labels, _ = _parse_sorted_nodes_list(listing.lines)
    self.assertIn(dirty_state, labels[names.index("a")])
    self.assertNotIn(dirty_state, labels[names.index("b")])

    # Step 2: update `b` with the -r (restore) flag.
    cont_result = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"])
    self.assertItemsEqual([self.b.name], _parse_updated(cont_result.lines))

    # `b` is now dirty; `a` should have been restored and is clean again.
    listing = cli.list_sorted_nodes([])
    names, labels, _ = _parse_sorted_nodes_list(listing.lines)
    self.assertIn(dirty_state, labels[names.index("b")])
    self.assertNotIn(dirty_state, labels[names.index("a")])
def testContToValidNodeShouldUpdateStatus(self):
  """cont() to nodes in the closure moves the pointer and labels them.

  Continues first to "c" and then to "d". After each call it checks that
  the node pointer lands on the target node and that the target carries
  the STATE_CONT label. The second cont() is also expected to report "c:0"
  as a feed of type "handle".
  """
  cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

  output = cli.list_sorted_nodes([])
  node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
      output.lines)

  # Before any stepping, "c" is unlabeled and the pointer is at the start.
  index_c = node_names.index("c")
  self.assertEqual(" ", stat_labels[index_c])
  self.assertEqual(0, node_pointer)

  # Pass argv as a list of tokens. A bare string would presumably be split
  # into single characters by the argument parser, which only happens to
  # work for one-letter node names.
  output = cli.cont(["c"])
  node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
      output.lines)
  self.assertGreaterEqual(len(node_names), 3)
  self.assertIn("c", node_names)
  index_c = node_names.index("c")
  self.assertEqual(index_c, node_pointer)
  self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])

  output = cli.cont(["d"])
  node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
      output.lines)

  # Continuing to "d" is expected to feed "c:0" as a handle.
  used_feed_types = _parsed_used_feeds(output.lines)
  self.assertEqual({"c:0": "handle"}, used_feed_types)

  self.assertGreaterEqual(len(node_names), 3)
  self.assertIn("d", node_names)
  index_d = node_names.index("d")
  self.assertEqual(index_d, node_pointer)
  self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
  """cont() to a node with no closure output does not label it STATE_CONT.

  Picks a sorted node where neither output ":0" nor output ":1" belongs to
  the stepper's closure elements. It then continues to that node and checks
  that the pointer is on it and that it lacks the STATE_CONT label.
  """
  node_stepper = stepper.NodeStepper(self.sess, self.opt)
  sorted_nodes = node_stepper.sorted_nodes()
  closure_elements = node_stepper.closure_elements()

  # First sorted node whose :0 and :1 outputs are both outside the closure.
  no_output_node = next(
      (candidate for candidate in sorted_nodes
       if candidate + ":0" not in closure_elements and
       candidate + ":1" not in closure_elements),
      None)
  self.assertIsNotNone(no_output_node)

  cli = stepper_cli.NodeStepperCLI(node_stepper)
  cont_result = cli.cont([no_output_node])
  names, labels, pointer = _parse_sorted_nodes_list(cont_result.lines)

  self.assertEqual(no_output_node, names[pointer])
  self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, labels[pointer])
def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
  """cont() with -r restores variables dirtied by an earlier cont().

  Continuing to the update op of `a` dirties only `a`. Continuing to the
  update op of `b` with "-r" dirties `b` and should restore `a`, so `a`
  is no longer labeled dirty.
  """
  cli = stepper_cli.NodeStepperCLI(
      stepper.NodeStepper(self.sess, self.opt))

  # The CLI output of cont() is not inspected; state is checked through
  # list_sorted_nodes() instead.
  cli.cont(["opt/update_a/ApplyGradientDescent"])

  # After cont() call on .../update_a/..., Variable a should have been marked
  # as dirty, whereas b should not have.
  output = cli.list_sorted_nodes([])
  node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
  self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                stat_labels[node_names.index("a")])
  self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                   stat_labels[node_names.index("b")])

  cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])

  # After cont() call on .../update_b/... with the -r flag, Variable b should
  # have been marked as dirty, whereas Variable a should not be because it
  # should have been restored.
  output = cli.list_sorted_nodes([])
  node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
  self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                stat_labels[node_names.index("b")])
  self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                   stat_labels[node_names.index("a")])
def testInjectTensorValueByTensorNameShouldBeReflected(self):
  """Injecting a value by tensor name overrides d:0 and relabels "d".

  After continuing to "d", the node is labeled STATE_CONT and d:0 evaluates
  to -20.0. Injecting 20.0 into "d:0" must register an override. It must
  also switch the node's label from STATE_CONT to STATE_OVERRIDDEN in the
  sorted-node listing.
  """
  node_stepper = stepper.NodeStepper(self.sess, self.e)
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  output = cli.cont(["d"])
  node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
  self.assertEqual("d", node_names[node_pointer])

  output = cli.list_sorted_nodes([])
  node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
  index_d = node_names.index("d")
  self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
  self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
                   stat_labels[index_d])

  self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))

  # The CLI output of inject_value() is not inspected; its effect is
  # verified through the stepper and the sorted-node listing below.
  cli.inject_value(["d:0", "20.0"])

  # Verify that the override is available.
  self.assertEqual(["d:0"], node_stepper.override_names())

  # Verify that the list of sorted nodes reflects the existence of the value
  # override (i.e., injection).
  output = cli.list_sorted_nodes([])
  node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
  index_d = node_names.index("d")
  self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                   stat_labels[index_d])
  self.assertIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
                stat_labels[index_d])
def testContToNodeOutsideTransitiveClosureShouldError(self):
  """cont() to node "f", outside the closure of target e, reports an error."""
  node_stepper = stepper.NodeStepper(self.sess, self.e)
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  cont_result = cli.cont(["f"])

  expected_lines = [
      "ERROR: f is not in the transitive closure of this stepper "
      "instance."
  ]
  self.assertEqual(expected_lines, cont_result.lines)
def testPrintTensorWithNonexistentTensorShouldError(self):
  """print_tensor on an unknown name reports a not-in-closure error."""
  node_stepper = stepper.NodeStepper(self.sess, self.e)
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  print_result = cli.print_tensor(["foobar"])

  expected_lines = [
      "ERROR: foobar is not in the transitive closure of this stepper "
      "instance."
  ]
  self.assertEqual(expected_lines, print_result.lines)
def testPrintTensorWithNoHandleShouldError(self):
  """print_tensor on a tensor whose value was never obtained reports so.

  No cont()/step() has been issued, so the stepper has no value for "e:0".
  """
  cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

  # Pass argv as a list of tokens. A bare string would presumably be split
  # into single characters by the argument parser, which only happens to
  # work for one-letter names.
  output = cli.print_tensor(["e"])

  self.assertEqual([
      "This stepper instance does not have access to the value of tensor "
      "\"e:0\""
  ], output.lines)
def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self):
  """print_tensor on node name "d" prints the value of tensor "d:0"."""
  cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

  # Pass argv as a list of tokens rather than a bare string, which would
  # presumably be split into characters by the argument parser.
  cli.cont(["d"])

  output = cli.print_tensor(["d"])
  self.assertEqual("Tensor \"d:0\":", output.lines[0])
  self.assertEqual("-20.0", output.lines[-1])
def testContToNonexistentNodeShouldError(self):
  """cont() to a node name that does not exist reports a not-in-closure error."""
  node_stepper = stepper.NodeStepper(self.sess, self.f)
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  cont_result = cli.cont(["foobar"])

  expected_lines = [
      "ERROR: foobar is not in the transitive closure of this stepper "
      "instance."
  ]
  self.assertEqual(expected_lines, cont_result.lines)
def testCallingInvokeNodeStepperOnDumpingWrapperRaisesException(self):
  """The dumping wrapper rejects invoke_node_stepper() with NotImplementedError."""
  dumping_sess = dumping_wrapper.DumpingDebugWrapperSession(
      self.sess, session_root=self.session_root, log_usage=False)
  node_stepper = stepper.NodeStepper(self.sess, self.inc_v)

  expected_message = (
      r"NonInteractiveDebugWrapperSession does not support "
      r"node-stepper mode\.")
  with self.assertRaisesRegexp(NotImplementedError, expected_message):
    dumping_sess.invoke_node_stepper(node_stepper)
def testPrintTensorShouldWorkWithTensorName(self):
  """print_tensor on tensor name "d:0" prints its value after cont() to d."""
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    # Pass argv as a list of tokens rather than a bare string, which would
    # presumably be split into characters by the argument parser.
    cli.cont(["d"])

    output = cli.print_tensor(["d:0"])
    self.assertEqual("Tensor \"d:0\":", output.lines[0])
    self.assertEqual("-20.0", output.lines[-1])
def testInjectToNonexistentTensorShouldError(self):
  """inject_value into an unknown tensor reports a not-in-closure error."""
  cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

  inject_result = cli.inject_value(["foobar:0", "20.0"])

  expected_lines = [
      "ERROR: foobar:0 is not in the transitive closure of this stepper "
      "instance."
  ]
  self.assertEqual(expected_lines, inject_result.lines)
def testContToUpdateNodeLeadsToDirtyVariableLabel(self):
  """cont() to b's update op labels `b`, and only `b`, as a dirty variable."""
  cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))

  # The CLI output of cont() is not inspected; state is checked through
  # list_sorted_nodes() instead.
  cli.cont(["opt/update_b/ApplyGradientDescent"])

  output = cli.list_sorted_nodes([])
  node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
  self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                stat_labels[node_names.index("b")])
  self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                   stat_labels[node_names.index("a")])
def testListingSortedNodesPresentsTransitveClosure(self):
  """The initial listing is topologically sorted, unlabeled, with pointer 0."""
  node_stepper = stepper.NodeStepper(self.sess, self.e)
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  listing = cli.list_sorted_nodes([])
  names, labels, pointer = _parse_sorted_nodes_list(listing.lines)

  self._assert_nodes_topologically_sorted_with_target_e(names)
  self.assertEqual(len(names), len(labels))

  # Nothing has been stepped to yet, so every status label is a single blank.
  for label in labels:
    self.assertEqual(" ", label)

  self.assertEqual(0, pointer)
def testPrintTensorShouldWorkSlicingString(self):
  """print_tensor honors a numpy-style slice on tensor and node names.

  Both "ph:0[:, 1]" and "ph[:, 1]" should print the same header and the repr
  of the second column of the fed placeholder value.
  """
  ph_value = np.array([[1.0, 0.0], [0.0, 2.0]])
  node_stepper = stepper.NodeStepper(
      self.sess, self.f, feed_dict={self.ph: ph_value})
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  for sliced_name in ("ph:0[:, 1]", "ph[:, 1]"):
    print_result = cli.print_tensor([sliced_name])
    self.assertEqual("Tensor \"ph:0[:, 1]\":", print_result.lines[0])
    self.assertEqual(repr(ph_value[:, 1]), print_result.lines[-1])
def testContToUpdateNodeWithoutTrackingLeadsToNoDirtyVariableLabel(self):
  """cont() to b's update op without extra flags reports `b` as updated.

  NOTE(review): the test name says "NoDirtyVariableLabel", but the assertions
  below require `b` to carry STATE_DIRTY_VARIABLE. Confirm which behavior is
  intended and rename or adjust accordingly.
  """
  dirty_state = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

  with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    cont_result = cli.cont(["opt/update_b/ApplyGradientDescent"])
    self.assertItemsEqual([self.b.name], _parse_updated(cont_result.lines))

    listing = cli.list_sorted_nodes([])
    names, labels, _ = _parse_sorted_nodes_list(listing.lines)
    self.assertIn(dirty_state, labels[names.index("b")])
    self.assertNotIn(dirty_state, labels[names.index("a")])
def testSteppingMultipleStepsUpdatesStatus(self):
  """step -t 3 lands on the third sorted node and labels the nodes before it.

  Every node listed before the pointer must carry STATE_CONT; no node after
  the pointer may carry it.
  """
  cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

  initial_listing = cli.list_sorted_nodes([])
  orig_node_names, _, _ = _parse_sorted_nodes_list(initial_listing.lines)

  step_result = cli.step(["-t", "3"])
  names, labels, pointer = _parse_sorted_nodes_list(step_result.lines)
  self.assertEqual(orig_node_names[2], names[pointer])

  cont_state = stepper_cli.NodeStepperCLI.STATE_CONT
  # The label at the pointer itself is not checked either way.
  for position, label in enumerate(labels):
    if position < pointer:
      self.assertIn(cont_state, label)
    elif position > pointer:
      self.assertNotIn(cont_state, label)
def before_run(self, run_context):
  """Runs the debug wrapper's on-run-start logic for a SessionRunHook call.

  On the first call, lazily initializes the underlying
  LocalCLIDebugWrapperSession and registers any tensor filters added before
  that point. Each call increments the run counter and forwards the fetches
  and feeds to on_run_start(). It then acts on the chosen action:
  DEBUG_RUN decorates the returned RunOptions for debugging, and
  INVOKE_STEPPER runs the node stepper interactively before returning.

  Args:
    run_context: The SessionRunContext supplied by the monitored session.

  Returns:
    A SessionRunArgs with no extra fetches or feeds and a fresh RunOptions.
    The RunOptions is decorated for debugging when the action is DEBUG_RUN.
  """
  if not self._wrapper_initialized:
    local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
        self, run_context.session, ui_type=self._ui_type)

    # Actually register tensor filters registered prior to the construction
    # of the underlying LocalCLIDebugWrapperSession object.
    for filter_name, filter_callable in self._pending_tensor_filters.items():
      local_cli_wrapper.LocalCLIDebugWrapperSession.add_tensor_filter(
          self, filter_name, filter_callable)

    self._wrapper_initialized = True

  # Increment run call counter.
  self._run_call_count += 1

  # Adapt run_context to an instance of OnRunStartRequest for invoking
  # superclass on_run_start().
  on_run_start_request = framework.OnRunStartRequest(
      run_context.original_args.fetches, run_context.original_args.feed_dict,
      None, None, self._run_call_count)

  on_run_start_response = self.on_run_start(on_run_start_request)
  self._performed_action = on_run_start_response.action

  run_args = session_run_hook.SessionRunArgs(
      None, feed_dict=None, options=config_pb2.RunOptions())
  if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
    self._decorate_options_for_debug(run_args.options,
                                     run_context.session.graph)
  elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
    graph = run_context.session.graph
    # The _finalized property must be set to False so that the NodeStepper
    # can insert ops for retrieving TensorHandles. Remember the previous
    # value and restore it once stepping ends (even on error). Otherwise the
    # graph would stay unfinalized for the rest of the session, silently
    # allowing accidental op creation.
    # pylint: disable=protected-access
    was_finalized = graph._finalized
    graph._finalized = False
    try:
      with stepper.NodeStepper(
          run_context.session,
          run_context.original_args.fetches,
          run_context.original_args.feed_dict) as node_stepper:
        self.invoke_node_stepper(
            node_stepper, restore_variable_values_on_exit=True)
    finally:
      graph._finalized = was_finalized
    # pylint: enable=protected-access

  return run_args
def testContToValidNodeShouldUpdateStatus(self):
  """cont() to nodes in the closure moves the pointer and labels them.

  Continues first to "c" and then to "d". Neither call should report any
  updated variables. After each call the node pointer must land on the
  target, and the target must carry STATE_CONT. The second cont() is
  expected to feed "c:0" as a handle and "a/read:0" as a dumped
  intermediate.
  """
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    # Before any stepping, "c" is unlabeled and the pointer is at the start.
    index_c = node_names.index("c")
    self.assertEqual(" ", stat_labels[index_c])
    self.assertEqual(0, node_pointer)

    # Pass argv as a list of tokens. A bare string would presumably be split
    # into single characters by the argument parser, which only happens to
    # work for one-letter node names.
    output = cli.cont(["c"])
    self.assertIsNone(_parse_updated(output.lines))
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("c", node_names)
    index_c = node_names.index("c")
    self.assertEqual(index_c, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])

    output = cli.cont(["d"])
    self.assertIsNone(_parse_updated(output.lines))
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    used_feed_types = _parsed_used_feeds(output.lines)
    self.assertEqual(
        {
            "c:0": stepper.NodeStepper.FEED_TYPE_HANDLE,
            "a/read:0": stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
        }, used_feed_types)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("d", node_names)
    index_d = node_names.index("d")
    self.assertEqual(index_d, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
def testListingSortedNodesLabelsPlaceholders(self):
  """Only the placeholder "ph" is labeled STATE_IS_PLACEHOLDER in the listing."""
  node_stepper = stepper.NodeStepper(self.sess, self.f)
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  listing = cli.list_sorted_nodes([])
  names, labels, pointer = _parse_sorted_nodes_list(listing.lines)

  self._assert_nodes_topologically_sorted_with_target_f(names)

  ph_position = names.index("ph")
  self.assertEqual(len(names), len(labels))

  placeholder_state = stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER
  for position, label in enumerate(labels):
    if position == ph_position:
      self.assertIn(placeholder_state, label)
    else:
      self.assertNotIn(placeholder_state, label)

  self.assertEqual(0, pointer)
def testSteppingOneStepAtATimeShouldUpdateStatus(self):
  """Repeated step() walks the sorted nodes in order, then errors at the end.

  Each step must land on the next node in the original order and label it
  STATE_CONT. The listing order must stay unchanged. The listing pointer
  advances by one and becomes -1 after the last node. One more step past
  the end returns an error message.
  """
  cont_state = stepper_cli.NodeStepperCLI.STATE_CONT

  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    initial_listing = cli.list_sorted_nodes([])
    orig_node_names, _, pointer = _parse_sorted_nodes_list(
        initial_listing.lines)
    self.assertEqual(0, pointer)

    last_position = len(orig_node_names) - 1
    for position, expected_name in enumerate(orig_node_names):
      step_result = cli.step([])
      names, labels, pointer = _parse_sorted_nodes_list(step_result.lines)
      self.assertEqual(expected_name, names[pointer])
      self.assertIn(cont_state, labels[pointer])

      # The order in which the nodes are listed should not change as the
      # stepping happens.
      listing = cli.list_sorted_nodes([])
      names, _, pointer = _parse_sorted_nodes_list(listing.lines)
      self.assertEqual(orig_node_names, names)

      # After the final node the pointer runs past the limit and becomes -1.
      expected_pointer = position + 1 if position < last_position else -1
      self.assertEqual(expected_pointer, pointer)

    # Attempt to step once more after the end has been reached should error
    # out.
    overflow_result = cli.step([])
    self.assertEqual([
        "ERROR: Cannot step any further because the end of the sorted "
        "transitive closure has been reached."
    ], overflow_result.lines)
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
  """Wrapper around Session.run() that inserts tensor watch options.

  Args:
    fetches: Same as the `fetches` arg to regular `Session.run()`.
    feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
    options: Same as the `options` arg to regular `Session.run()`. Not
      modified by this method; debug watches are added to a private copy.
    run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.

  Returns:
    Simply forwards the output of the wrapped `Session.run()` call. On the
    DEBUG_RUN path, if the wrapped run raises `errors.OpError`, the error
    object itself is returned instead of being raised.

  Raises:
    ValueError: On invalid `OnRunStartAction` value.
  """
  self._run_call_count += 1

  # Invoke on-run-start callback and obtain response.
  run_start_resp = self.on_run_start(
      OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                        self._run_call_count))
  _check_type(run_start_resp, OnRunStartResponse)

  if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
    # _decorate_run_options() mutates the RunOptions it is given (its return
    # value is unused). Decorate a private copy so the debugger's tensor
    # watches do not leak into the caller's `options` object. They would
    # otherwise pile up across repeated run() calls that reuse that object.
    decorated_run_options = config_pb2.RunOptions()
    if options is not None:
      decorated_run_options.CopyFrom(options)
    run_metadata = run_metadata or config_pb2.RunMetadata()

    # Decorate RunOption to fill in debugger tensor watch specifications.
    self._decorate_run_options(
        decorated_run_options,
        run_start_resp.debug_urls,
        debug_ops=run_start_resp.debug_ops,
        node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
        op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist)

    # Invoke the run() method of the wrapped Session. Catch any TensorFlow
    # runtime errors.
    tf_error = None
    try:
      retvals = self._sess.run(fetches,
                               feed_dict=feed_dict,
                               options=decorated_run_options,
                               run_metadata=run_metadata)
    except errors.OpError as op_error:
      tf_error = op_error
      retvals = op_error

    run_end_req = OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def(),
        tf_error=tf_error)

  elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
        run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
    if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
      # Use the stepper as a context manager so its __exit__ cleanup runs
      # once stepping ends. The return value of invoke_node_stepper() is not
      # kept: the regular run() below always supplies `retvals`.
      with stepper.NodeStepper(self._sess, fetches,
                               feed_dict) as node_stepper:
        self.invoke_node_stepper(
            node_stepper, restore_variable_values_on_exit=True)

    # Invoke run() method of the wrapped session.
    retvals = self._sess.run(fetches,
                             feed_dict=feed_dict,
                             options=options,
                             run_metadata=run_metadata)

    # Prepare arg for the on-run-end callback.
    run_end_req = OnRunEndRequest(run_start_resp.action)
  else:
    raise ValueError("Invalid OnRunStartAction value: %s" %
                     run_start_resp.action)

  # Invoke on-run-end callback and obtain response.
  run_end_resp = self.on_run_end(run_end_req)
  _check_type(run_end_resp, OnRunEndResponse)
  # Currently run_end_resp is only a placeholder. No action is taken on it.

  return retvals
def testInjectTensorValueByNodeNameShouldBeReflected(self):
  """inject_value with the bare node name "d" registers an override on d:0."""
  node_stepper = stepper.NodeStepper(self.sess, self.e)
  stepper_cli.NodeStepperCLI(node_stepper).inject_value(["d", "20.0"])

  self.assertEqual(["d:0"], node_stepper.override_names())