def testInjectTensorValueByTensorNameShouldBeReflected(self):
  """Checks that injecting "d:0" by tensor name shows up in the node list.

  After continuing to node "d" its status labels should contain the "cont"
  state and not the "overridden" state; after injecting a value for "d:0"
  the two labels should be swapped.
  """
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    cont_label = stepper_cli.NodeStepperCLI.STATE_CONT
    overridden_label = stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN

    def labels_of_d():
      # Re-list the sorted nodes and return the status labels of node "d".
      names, labels, _ = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      return labels[names.index("d")]

    # Continue to "d": the node pointer should now sit on that node.
    names, _, pointer = _parse_sorted_nodes_list(cli.cont(["d"]).lines)
    self.assertEqual("d", names[pointer])

    # Before the injection: continued-to, but not overridden.
    d_labels = labels_of_d()
    self.assertIn(cont_label, d_labels)
    self.assertNotIn(overridden_label, d_labels)

    self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))

    cli.inject_value(["d:0", "20.0"])

    # The stepper itself should now report the override.
    self.assertEqual(["d:0"], node_stepper.override_names())

    # The sorted-node listing should reflect the injection as well:
    # overridden, and no longer marked as continued-to.
    d_labels = labels_of_d()
    self.assertNotIn(cont_label, d_labels)
    self.assertIn(overridden_label, d_labels)
def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
  """Checks that `cont -r` restores previously-dirtied variables.

  Continues to the update op of Variable "a", verifies that only "a" is
  labeled dirty, then continues to the update op of Variable "b" with the
  "-r" flag and verifies that only "b" is labeled dirty afterwards.
  """
  dirty_label = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

  # Use the stepper as a context manager, like the other tests that build a
  # NodeStepper, so that its exit hook runs when the test finishes.
  with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    cli.cont(["opt/update_a/ApplyGradientDescent"])

    # After cont() call on .../update_a/..., Variable a should have been
    # marked as dirty, whereas b should not have.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(dirty_label, stat_labels[node_names.index("a")])
    self.assertNotIn(dirty_label, stat_labels[node_names.index("b")])

    cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])

    # After cont() call on .../update_b/... with the -r flag, Variable b
    # should have been marked as dirty, whereas Variable a should not be
    # because it should have been restored.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(dirty_label, stat_labels[node_names.index("b")])
    self.assertNotIn(dirty_label, stat_labels[node_names.index("a")])
def testSteppingOneStepAtATimeShouldUpdateStatus(self):
  """Steps through every sorted node and checks pointer and status labels.

  Each step() should land on the next node of the original listing and label
  it as continued-to, the listing order must stay stable, and stepping past
  the last node must produce the "cannot step any further" error.
  """
  # Use the stepper as a context manager, like the other tests that build a
  # NodeStepper, so that its exit hook runs when the test finishes.
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    output = cli.list_sorted_nodes([])
    orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
    self.assertEqual(0, node_pointer)

    last_index = len(orig_node_names) - 1
    for i, expected_name in enumerate(orig_node_names):
      output = cli.step([])
      node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
          output.lines)

      # The step output should point at the node that was just stepped to.
      self.assertEqual(expected_name, node_names[node_pointer])
      self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                    stat_labels[node_pointer])

      # The order in which the nodes are listed should not change as the
      # stepping happens.
      output = cli.list_sorted_nodes([])
      node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
      self.assertEqual(orig_node_names, node_names)

      if i < last_index:
        self.assertEqual(i + 1, node_pointer)
      else:
        # Stepped over the limit. Pointer should be at -1.
        self.assertEqual(-1, node_pointer)

    # Attempt to step once more after the end has been reached should error
    # out.
    output = cli.step([])
    self.assertEqual([
        "ERROR: Cannot step any further because the end of the sorted "
        "transitive closure has been reached."
    ], output.lines)
def before_run(self, run_context):
  """Prepares the run arguments according to the wrapper's on-run-start action.

  Lazily creates the `LocalCLIDebugWrapperSession` on first use (registering
  any pending tensor filters with it), bumps its run-call counter, asks it
  for an on-run-start response and decorates fresh `RunOptions` for a debug
  or profile run, or launches the node stepper, depending on that response.

  Args:
    run_context: The hook's run context; its `session` and `original_args`
      (`fetches`, `feed_dict`) are read here.

  Returns:
    A `SessionRunArgs` carrying no fetches or feeds and the (possibly
    decorated) `RunOptions`.
  """
  if not self._session_wrapper:
    self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
        run_context.session,
        ui_type=self._ui_type,
        dump_root=self._dump_root,
        thread_name_filter=self._thread_name_filter)

    # Tensor filters added before the wrapper session existed were parked in
    # _pending_tensor_filters; hand them over now.
    for filter_name in self._pending_tensor_filters:
      self._session_wrapper.add_tensor_filter(
          filter_name, self._pending_tensor_filters[filter_name])

  wrapper = self._session_wrapper
  wrapper.increment_run_call_count()

  # Express the hook's run_context as an OnRunStartRequest so the wrapper's
  # on_run_start() can be reused.
  original_args = run_context.original_args
  start_request = framework.OnRunStartRequest(
      original_args.fetches, original_args.feed_dict, None, None,
      wrapper.run_call_count)
  start_response = wrapper.on_run_start(start_request)
  self._performed_action = start_response.action

  run_args = session_run_hook.SessionRunArgs(
      None, feed_dict=None, options=config_pb2.RunOptions())

  action = self._performed_action
  if action == framework.OnRunStartAction.DEBUG_RUN:
    # pylint: disable=protected-access
    wrapper._decorate_run_options_for_debug(
        run_args.options,
        start_response.debug_urls,
        debug_ops=start_response.debug_ops,
        node_name_regex_whitelist=start_response.node_name_regex_whitelist,
        op_type_regex_whitelist=start_response.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=(
            start_response.tensor_dtype_regex_whitelist),
        tolerate_debug_op_creation_failures=(
            start_response.tolerate_debug_op_creation_failures))
    # pylint: enable=protected-access
  elif action == framework.OnRunStartAction.PROFILE_RUN:
    # pylint: disable=protected-access
    wrapper._decorate_run_options_for_profile(run_args.options)
    # pylint: enable=protected-access
  elif action == framework.OnRunStartAction.INVOKE_STEPPER:
    # The _finalized property must be set to False so that the NodeStepper
    # can insert ops for retrieving TensorHandles.
    # pylint: disable=protected-access
    run_context.session.graph._finalized = False
    # pylint: enable=protected-access

    with stepper.NodeStepper(
        run_context.session,
        run_context.original_args.fetches,
        run_context.original_args.feed_dict) as node_stepper:
      wrapper.invoke_node_stepper(
          node_stepper, restore_variable_values_on_exit=True)

  return run_args
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
  """Runs the wrapped session, bracketed by the on-run-start/end callbacks.

  Args:
    fetches: Same as the `fetches` arg to regular `Session.run()`.
    feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
    options: Same as the `options` arg to regular `Session.run()`.
    run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.

  Returns:
    The output of the wrapped `Session.run()` call. In a debug run that
    raised an `OpError`, the error object itself is returned instead.

  Raises:
    ValueError: On invalid `OnRunStartAction` value.
  """
  self._run_call_count += 1

  # Ask the on-run-start callback what kind of run this should be.
  start_request = OnRunStartRequest(
      fetches, feed_dict, options, run_metadata, self._run_call_count)
  run_start_resp = self.on_run_start(start_request)
  _check_type(run_start_resp, OnRunStartResponse)
  action = run_start_resp.action

  if action == OnRunStartAction.DEBUG_RUN:
    # Fill the debugger's tensor watch specifications into the RunOptions,
    # creating the options/metadata protos if the caller supplied none.
    decorated_run_options = options or config_pb2.RunOptions()
    run_metadata = run_metadata or config_pb2.RunMetadata()
    self._decorate_run_options(
        decorated_run_options,
        run_start_resp.debug_urls,
        debug_ops=run_start_resp.debug_ops,
        node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
        op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=(
            run_start_resp.tensor_dtype_regex_whitelist),
        tolerate_debug_op_creation_failures=(
            run_start_resp.tolerate_debug_op_creation_failures))

    # Run the wrapped session; a TensorFlow runtime error is captured and
    # handed to the on-run-end callback rather than propagated.
    tf_error = None
    try:
      retvals = self._sess.run(
          fetches,
          feed_dict=feed_dict,
          options=decorated_run_options,
          run_metadata=run_metadata)
    except errors.OpError as op_error:
      tf_error = retvals = op_error

    run_end_req = OnRunEndRequest(
        action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def(),
        tf_error=tf_error)
  elif action in (OnRunStartAction.NON_DEBUG_RUN,
                  OnRunStartAction.INVOKE_STEPPER):
    if action == OnRunStartAction.INVOKE_STEPPER:
      with stepper.NodeStepper(
          self._sess, fetches, feed_dict) as node_stepper:
        # This value is superseded by the plain run just below.
        retvals = self.invoke_node_stepper(
            node_stepper, restore_variable_values_on_exit=True)

    # Plain, undecorated run of the wrapped session.
    retvals = self._sess.run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    # Only the action is reported to the on-run-end callback here.
    run_end_req = OnRunEndRequest(action)
  else:
    raise ValueError("Invalid OnRunStartAction value: %s" % action)

  # The on-run-end response is type-checked but otherwise unused for now.
  run_end_resp = self.on_run_end(run_end_req)
  _check_type(run_end_resp, OnRunEndResponse)

  return retvals
def run(self,
        fetches,
        feed_dict=None,
        options=None,
        run_metadata=None,
        callable_runner=None,
        callable_runner_args=None):
  """Wrapper around Session.run() that inserts tensor watch options.

  Args:
    fetches: Same as the `fetches` arg to regular `Session.run()`.
    feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
    options: Same as the `options` arg to regular `Session.run()`.
    run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
    callable_runner: A `callable` returned by `Session.make_callable()`.
      If not `None`, `fetches` and `feed_dict` must both be `None`.
    callable_runner_args: An optional list of arguments to `callable_runner`.
      `None` (the default) is treated as "no arguments".

  Returns:
    Simply forwards the output of the wrapped `Session.run()` call (or of
    `callable_runner`). In a debug run that raised an `OpError` which is not
    passed through, the error object itself is returned instead.

  Raises:
    ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
      is not `None` and either or both of `fetches` and `feed_dict` are
      supplied.
    NotImplementedError: If the stepper is requested for a run that uses
      `callable_runner`.
  """
  if not callable_runner:
    self.increment_run_call_count()
  elif fetches or feed_dict:
    raise ValueError(
        "callable_runner and fetches/feed_dict are mutually exclusive, but "
        "are used simultaneously.")

  # The args are always unpacked with `*`, so a missing list must become an
  # empty sequence; unpacking None would raise TypeError.
  if callable_runner_args is None:
    callable_runner_args = ()

  empty_fetches = not nest.flatten(fetches)
  if empty_fetches:
    tf_logging.info(
        "Due to empty fetches, tfdbg Session wrapper is letting a "
        "Session.run pass through without any debugging actions.")
  if self._is_disabled_thread() or empty_fetches:
    # Pass-through: no callbacks, no decoration.
    if callable_runner:
      return callable_runner(*callable_runner_args)
    return self._sess.run(fetches,
                          feed_dict=feed_dict,
                          options=options,
                          run_metadata=run_metadata)

  # Invoke on-run-start callback and obtain response.
  run_start_resp = self.on_run_start(
      OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                        self._run_call_count,
                        is_callable_runner=bool(callable_runner)))
  _check_type(run_start_resp, OnRunStartResponse)
  action = run_start_resp.action

  if action == OnRunStartAction.DEBUG_RUN:
    # Decorate RunOption to fill in debugger tensor watch specifications.
    decorated_run_options = options or config_pb2.RunOptions()
    run_metadata = run_metadata or config_pb2.RunMetadata()
    self._decorate_run_options_for_debug(
        decorated_run_options,
        run_start_resp.debug_urls,
        debug_ops=run_start_resp.debug_ops,
        node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
        op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=(
            run_start_resp.tensor_dtype_regex_whitelist),
        tolerate_debug_op_creation_failures=(
            run_start_resp.tolerate_debug_op_creation_failures))

    # Invoke the run() method of the wrapped Session. Catch any TensorFlow
    # runtime errors.
    tf_error = None
    try:
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
    except errors.OpError as op_error:
      if self._pass_through_operrors:
        # Bare raise re-raises the active exception unchanged.
        raise
      tf_error = op_error
      retvals = op_error

    run_end_req = OnRunEndRequest(
        action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def(),
        tf_error=tf_error)

  elif action == OnRunStartAction.PROFILE_RUN:
    decorated_run_options = options or config_pb2.RunOptions()
    run_metadata = run_metadata or config_pb2.RunMetadata()
    self._decorate_run_options_for_profile(decorated_run_options)
    if callable_runner:
      retvals = callable_runner(*callable_runner_args,
                                options=decorated_run_options,
                                run_metadata=run_metadata)
    else:
      retvals = self._sess.run(fetches,
                               feed_dict=feed_dict,
                               options=decorated_run_options,
                               run_metadata=run_metadata)

    run_end_req = OnRunEndRequest(
        action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def())

  elif (action == OnRunStartAction.NON_DEBUG_RUN or
        action == OnRunStartAction.INVOKE_STEPPER):
    if action == OnRunStartAction.INVOKE_STEPPER:
      # Only the stepper is unsupported for callables; a plain non-debug run
      # with a callable is handled below.
      if callable_runner:
        raise NotImplementedError(
            "Stepper mode is not implemented for callables created by "
            "Session.make_callable().")

      with stepper.NodeStepper(
          self._sess, fetches, feed_dict) as node_stepper:
        # The stepper's return value is not used; the run below supplies
        # the values returned to the caller.
        self.invoke_node_stepper(
            node_stepper, restore_variable_values_on_exit=True)

    # Invoke the callable or the run() method of the wrapped session,
    # undecorated, in the same way as the pass-through path above.
    if callable_runner:
      retvals = callable_runner(*callable_runner_args)
    else:
      retvals = self._sess.run(
          fetches,
          feed_dict=feed_dict,
          options=options,
          run_metadata=run_metadata)

    # Prepare arg for the on-run-end callback.
    run_end_req = OnRunEndRequest(action)
  else:
    raise ValueError("Invalid OnRunStartAction value: %s" % action)

  # Invoke on-run-end callback and obtain response.
  run_end_resp = self.on_run_end(run_end_req)
  _check_type(run_end_resp, OnRunEndResponse)
  # Currently run_end_resp is only a placeholder. No action is taken on it.

  return retvals
def testInjectTensorValueByNodeNameShouldBeReflected(self):
  """Checks that injecting by node name "d" registers an override for "d:0"."""
  with stepper.NodeStepper(self.sess, self.e) as node_stepper:
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    # Inject using the bare node name rather than the full tensor name.
    inject_args = ["d", "20.0"]
    cli.inject_value(inject_args)

    overrides = node_stepper.override_names()
    self.assertEqual(["d:0"], overrides)