Example #1
    def testInjectTensorValueByTensorNameShouldBeReflected(self):
        with stepper.NodeStepper(self.sess, self.e) as node_stepper:
            cli = stepper_cli.NodeStepperCLI(node_stepper)

            output = cli.cont(["d"])
            node_names, _, node_pointer = _parse_sorted_nodes_list(
                output.lines)
            self.assertEqual("d", node_names[node_pointer])

            output = cli.list_sorted_nodes([])
            node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
                output.lines)

            index_d = node_names.index("d")
            self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                          stat_labels[index_d])
            self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
                             stat_labels[index_d])

            self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))

            output = cli.inject_value(["d:0", "20.0"])

            # Verify that the override is available.
            self.assertEqual(["d:0"], node_stepper.override_names())

            # Verify that the list of sorted nodes reflects the existence of the value
            # override (i.e., injection).
            output = cli.list_sorted_nodes([])
            node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
                output.lines)

            index_d = node_names.index("d")
            self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                             stat_labels[index_d])
            self.assertIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
                          stat_labels[index_d])

  def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
    output = cli.cont(["opt/update_a/ApplyGradientDescent"])

    # After the cont() call on .../update_a/..., Variable a should have been
    # marked as dirty, whereas b should not have been.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("a")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("b")])

    output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])

    # After the cont() call on .../update_b/... with the -r flag, Variable b
    # should have been marked as dirty, whereas Variable a should not be,
    # because its value should have been restored.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("b")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("a")])

  def testSteppingOneStepAtATimeShouldUpdateStatus(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
    self.assertEqual(0, node_pointer)

    for i in xrange(len(orig_node_names)):
      output = cli.step([])
      node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
          output.lines)

      next_node_name = node_names[node_pointer]
      self.assertEqual(orig_node_names[i], next_node_name)

      self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                    stat_labels[node_pointer])

      # The order in which the nodes are listed should not change as the
      # stepping happens.
      output = cli.list_sorted_nodes([])
      node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
      self.assertEqual(orig_node_names, node_names)

      if i < len(orig_node_names) - 1:
        self.assertEqual(i + 1, node_pointer)
      else:
        # Stepped over the limit. Pointer should be at -1.
        self.assertEqual(-1, node_pointer)

    # Attempting to step once more after the end has been reached should
    # error out.
    output = cli.step([])
    self.assertEqual([
        "ERROR: Cannot step any further because the end of the sorted "
        "transitive closure has been reached."
    ], output.lines)
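
The tests above use a fixture (self.sess, self.e, self.opt) defined in the test class's setUp, which is not shown on this page. Below is a minimal, hypothetical sketch of such a fixture, chosen only so that the node names and the -20.0 value asserted for d:0 are consistent with the assertions above; the actual TensorFlow test fixture may differ.

# Hypothetical fixture sketch (an assumption for illustration; not the actual
# TensorFlow test's setUp): Variables a and b, an intermediate node d whose
# value is -20.0, a fetch e, and a gradient-descent train op named "opt" whose
# per-variable update ops should be named "opt/update_a/ApplyGradientDescent"
# and "opt/update_b/ApplyGradientDescent".
from tensorflow.python.client import session
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import gradient_descent


class StepperCLIFixtureSketch(test_util.TensorFlowTestCase):

  def setUp(self):
    self.a = variables.Variable(10.0, name="a")
    self.b = variables.Variable(30.0, name="b")
    self.d = math_ops.subtract(self.a, self.b, name="d")  # 10.0 - 30.0 = -20.0
    self.e = math_ops.multiply(self.d, self.d, name="e")
    self.opt = gradient_descent.GradientDescentOptimizer(0.1).minimize(
        self.e, name="opt")
    self.sess = session.Session()
    self.sess.run(variables.global_variables_initializer())
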
Example #4
    def before_run(self, run_context):
        if not self._session_wrapper:
            self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
                run_context.session,
                ui_type=self._ui_type,
                dump_root=self._dump_root,
                thread_name_filter=self._thread_name_filter)

            # Register any tensor filters that were added before the
            # underlying LocalCLIDebugWrapperSession object was constructed.
            for filter_name in self._pending_tensor_filters:
                self._session_wrapper.add_tensor_filter(
                    filter_name, self._pending_tensor_filters[filter_name])

        # Increment run call counter.
        self._session_wrapper.increment_run_call_count()

        # Adapt run_context to an instance of OnRunStartRequest for invoking
        # the wrapped session's on_run_start().
        on_run_start_request = framework.OnRunStartRequest(
            run_context.original_args.fetches,
            run_context.original_args.feed_dict, None, None,
            self._session_wrapper.run_call_count)

        on_run_start_response = self._session_wrapper.on_run_start(
            on_run_start_request)
        self._performed_action = on_run_start_response.action

        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=config_pb2.RunOptions())
        if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
            # pylint: disable=protected-access
            self._session_wrapper._decorate_run_options_for_debug(
                run_args.options,
                on_run_start_response.debug_urls,
                debug_ops=on_run_start_response.debug_ops,
                node_name_regex_whitelist=(
                    on_run_start_response.node_name_regex_whitelist),
                op_type_regex_whitelist=(
                    on_run_start_response.op_type_regex_whitelist),
                tensor_dtype_regex_whitelist=(
                    on_run_start_response.tensor_dtype_regex_whitelist),
                tolerate_debug_op_creation_failures=(
                    on_run_start_response.tolerate_debug_op_creation_failures))
            # pylint: enable=protected-access
        elif self._performed_action == framework.OnRunStartAction.PROFILE_RUN:
            # pylint: disable=protected-access
            self._session_wrapper._decorate_run_options_for_profile(
                run_args.options)
            # pylint: enable=protected-access
        elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
            # The _finalized property must be set to False so that the NodeStepper
            # can insert ops for retrieving TensorHandles.
            # pylint: disable=protected-access
            run_context.session.graph._finalized = False
            # pylint: enable=protected-access

            with stepper.NodeStepper(
                    run_context.session, run_context.original_args.fetches,
                    run_context.original_args.feed_dict) as node_stepper:
                self._session_wrapper.invoke_node_stepper(
                    node_stepper, restore_variable_values_on_exit=True)

        return run_args
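
The before_run() above belongs to a session-run hook (LocalCLIDebugHook). A minimal usage sketch follows, showing how such a hook is typically attached to a MonitoredSession so that before_run() fires on every run call; the stand-in train op is an assumption.

# Minimal usage sketch (the train op is a stand-in; a real model would supply
# its own): attaching the debug hook so that before_run() above is invoked
# prior to every MonitoredSession.run() call.
from tensorflow.python import debug as tf_debug
from tensorflow.python.framework import constant_op
from tensorflow.python.training import monitored_session

hook = tf_debug.LocalCLIDebugHook(ui_type="curses")
# Filters added before the first run are held until before_run() constructs
# the underlying LocalCLIDebugWrapperSession, as shown above.
hook.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

train_op = constant_op.constant(0.0, name="stand_in_train_op")
with monitored_session.MonitoredSession(hooks=[hook]) as sess:
  sess.run(train_op)
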
  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Wrapper around Session.run() that inserts tensor watch options.

    Args:
      fetches: Same as the `fetches` arg to regular `Session.run()`.
      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
      options: Same as the `options` arg to regular `Session.run()`.
      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.

    Returns:
      Simply forwards the output of the wrapped `Session.run()` call.

    Raises:
      ValueError: On invalid `OnRunStartAction` value.
    """

    self._run_call_count += 1

    # Invoke on-run-start callback and obtain response.
    run_start_resp = self.on_run_start(
        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                          self._run_call_count))
    _check_type(run_start_resp, OnRunStartResponse)

    if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
      # Decorate RunOptions to fill in debugger tensor watch specifications.
      decorated_run_options = options or config_pb2.RunOptions()
      run_metadata = run_metadata or config_pb2.RunMetadata()

      self._decorate_run_options(
          decorated_run_options,
          run_start_resp.debug_urls,
          debug_ops=run_start_resp.debug_ops,
          node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
          op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
          tensor_dtype_regex_whitelist=(
              run_start_resp.tensor_dtype_regex_whitelist),
          tolerate_debug_op_creation_failures=(
              run_start_resp.tolerate_debug_op_creation_failures))

      # Invoke the run() method of the wrapped Session. Catch any TensorFlow
      # runtime errors.
      tf_error = None
      try:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
      except errors.OpError as op_error:
        tf_error = op_error
        retvals = op_error

      run_end_req = OnRunEndRequest(
          run_start_resp.action,
          run_metadata=run_metadata,
          client_graph_def=self._sess.graph.as_graph_def(),
          tf_error=tf_error)

    elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
          run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
      if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
        with stepper.NodeStepper(
            self._sess, fetches, feed_dict) as node_stepper:
          retvals = self.invoke_node_stepper(
              node_stepper, restore_variable_values_on_exit=True)
      else:
        # Invoke the run() method of the wrapped session.
        retvals = self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)

      # Prepare arg for the on-run-end callback.
      run_end_req = OnRunEndRequest(run_start_resp.action)
    else:
      raise ValueError(
          "Invalid OnRunStartAction value: %s" % run_start_resp.action)

    # Invoke on-run-end callback and obtain response.
    run_end_resp = self.on_run_end(run_end_req)
    _check_type(run_end_resp, OnRunEndResponse)
    # Currently run_end_resp is only a placeholder. No action is taken on it.

    return retvals
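
This run() override is normally reached through a concrete wrapper such as LocalCLIDebugWrapperSession. A minimal usage sketch, assuming that standard subclass:

# Minimal usage sketch, assuming the standard LocalCLIDebugWrapperSession
# subclass of this wrapper: every call to the wrapped object's run() goes
# through the method above.
from tensorflow.python import debug as tf_debug
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op

x = constant_op.constant([1.0, 2.0], name="x")
sess = tf_debug.LocalCLIDebugWrapperSession(session.Session())
sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
# The CLI shown at run-start decides run_start_resp.action in the code above
# (e.g. "run" -> DEBUG_RUN, "run -n" -> NON_DEBUG_RUN, "invoke_stepper").
sess.run(x)
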
  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None):
    """Wrapper around Session.run() that inserts tensor watch options.

    Args:
      fetches: Same as the `fetches` arg to regular `Session.run()`.
      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
      options: Same as the `options` arg to regular `Session.run()`.
      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
      callable_runner: A `callable` returned by `Session.make_callable()`.
        If not `None`, `fetches` and `feed_dict` must both be `None`.
      callable_runner_args: An optional list of arguments to `callable_runner`.

    Returns:
      Simply forwards the output of the wrapped `Session.run()` call.

    Raises:
      ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
        is not `None` and either or both of `fetches` and `feed_dict` are also
        not `None`.
    """
    if not callable_runner:
      self.increment_run_call_count()
    else:
      if fetches or feed_dict:
        raise ValueError(
            "callable_runner and fetches/feed_dict are mutually exclusive, but "
            "are used simultaneously.")

    empty_fetches = not nest.flatten(fetches)
    if empty_fetches:
      tf_logging.info(
          "Due to empty fetches, tfdbg Session wrapper is letting a "
          "Session.run pass through without any debugging actions.")
    if self._is_disabled_thread() or empty_fetches:
      if callable_runner:
        return callable_runner(*callable_runner_args)
      else:
        return self._sess.run(fetches,
                              feed_dict=feed_dict,
                              options=options,
                              run_metadata=run_metadata)

    # Invoke on-run-start callback and obtain response.
    run_start_resp = self.on_run_start(
        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                          self._run_call_count,
                          is_callable_runner=bool(callable_runner)))
    _check_type(run_start_resp, OnRunStartResponse)

    if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
      # Decorate RunOptions to fill in debugger tensor watch specifications.
      decorated_run_options = options or config_pb2.RunOptions()
      run_metadata = run_metadata or config_pb2.RunMetadata()

      self._decorate_run_options_for_debug(
          decorated_run_options,
          run_start_resp.debug_urls,
          debug_ops=run_start_resp.debug_ops,
          node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
          op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
          tensor_dtype_regex_whitelist=(
              run_start_resp.tensor_dtype_regex_whitelist),
          tolerate_debug_op_creation_failures=(
              run_start_resp.tolerate_debug_op_creation_failures))

      # Invoke the run() method of the wrapped Session. Catch any TensorFlow
      # runtime errors.
      tf_error = None
      try:
        if callable_runner:
          retvals = callable_runner(*callable_runner_args,
                                    options=decorated_run_options,
                                    run_metadata=run_metadata)
        else:
          retvals = self._sess.run(fetches,
                                   feed_dict=feed_dict,
                                   options=decorated_run_options,
                                   run_metadata=run_metadata)
      except errors.OpError as op_error:
        if self._pass_through_operrors:
          raise op_error
        tf_error = op_error
        retvals = op_error

      run_end_req = OnRunEndRequest(
          run_start_resp.action,
          run_metadata=run_metadata,
          client_graph_def=self._sess.graph.as_graph_def(),
          tf_error=tf_error)

    elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
      decorated_run_options = options or config_pb2.RunOptions()
      run_metadata = run_metadata or config_pb2.RunMetadata()
      self._decorate_run_options_for_profile(decorated_run_options)
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
      run_end_req = OnRunEndRequest(
          run_start_resp.action,
          run_metadata=run_metadata,
          client_graph_def=self._sess.graph.as_graph_def())
    elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
          run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
      if callable_runner:
        raise NotImplementedError(
            "Stepper mode is not implemented for callables created by "
            "Session.make_callable().")

      if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
        with stepper.NodeStepper(
            self._sess, fetches, feed_dict) as node_stepper:
          retvals = self.invoke_node_stepper(
              node_stepper, restore_variable_values_on_exit=True)
      else:
        # Invoke the run() method of the wrapped session.
        retvals = self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)

      # Prepare arg for the on-run-end callback.
      run_end_req = OnRunEndRequest(run_start_resp.action)
    else:
      raise ValueError(
          "Invalid OnRunStartAction value: %s" % run_start_resp.action)

    # Invoke on-run-end callback and obtain response.
    run_end_resp = self.on_run_end(run_end_req)
    _check_type(run_end_resp, OnRunEndResponse)
    # Currently run_end_resp is only a placeholder. No action is taken on it.

    return retvals
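
The callable_runner path documented above pairs this run() with a callable produced by Session.make_callable(). A hypothetical sketch (tensor names and values are assumptions):

# Hypothetical sketch of the callable_runner path described in the docstring
# above: fetches and feed_dict must be None, and the callable is created with
# accept_options=True so the decorated RunOptions can be forwarded to it on a
# DEBUG_RUN.
from tensorflow.python import debug as tf_debug
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

x = array_ops.placeholder(dtypes.float32, name="x")
y = math_ops.square(x, name="y")

sess = session.Session()
wrapped_sess = tf_debug.LocalCLIDebugWrapperSession(sess)
runner = sess.make_callable(y, feed_list=[x], accept_options=True)

result = wrapped_sess.run(None,
                          callable_runner=runner,
                          callable_runner_args=[3.0])
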
Example #7
  def testInjectTensorValueByNodeNameShouldBeReflected(self):
    with stepper.NodeStepper(self.sess, self.e) as node_stepper:
      cli = stepper_cli.NodeStepperCLI(node_stepper)

      cli.inject_value(["d", "20.0"])
      self.assertEqual(["d:0"], node_stepper.override_names())