示例#1
0
  def node_info(self, args, screen_info=None):
    """Command handler for node_info.

    Query information about a given node.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """
    # TODO(cais): Add annotation of substrings for node names, to facilitate
    # on-screen highlighting/selection of node names.
    del screen_info  # Unused by this handler.

    parsed = self._arg_parsers["node_info"].parse_args(args)

    # The argument may be either a bare node name or a tensor name with an
    # output slot suffix; keep only the node-name part.
    node_name, _ = self._parse_node_or_tensor_name(parsed.node_name)

    if not self._debug_dump.node_exists(node_name):
      return self._error(
          "There is no node named \"%s\" in the partition graphs" % node_name)

    # TODO(cais): Provide UI glossary feature to explain to users what the
    # term "partition graph" means and how it is related to TF graph objects
    # in Python: a TF graph defined in Python is stripped of unused ops
    # according to the feeds and fetches and divided into a number of
    # partition graphs that may be distributed among multiple devices and
    # hosts; the partition graphs are what's actually executed by the C++
    # runtime during a run() call.

    lines = [
        "Node %s" % node_name,
        "",
        "  Op: %s" % self._debug_dump.node_op_type(node_name),
        "  Device: %s" % self._debug_dump.node_device(node_name),
    ]

    # Non-control and control inputs of the node.
    lines.extend(
        self._format_neighbors(
            "input",
            self._debug_dump.node_inputs(node_name),
            self._debug_dump.node_inputs(node_name, is_control=True)))

    # Non-control and control output recipients of the node.
    lines.extend(
        self._format_neighbors(
            "recipient",
            self._debug_dump.node_recipients(node_name),
            self._debug_dump.node_recipients(node_name, is_control=True)))

    # Optional sections, gated by command-line flags.
    if parsed.attributes:
      lines.extend(self._list_node_attributes(node_name))
    if parsed.dumps:
      lines.extend(self._list_node_dumps(node_name))

    return debugger_cli_common.RichTextLines(lines)
示例#2
0
  def _list_inputs_or_outputs(self,
                              recursive,
                              node_name,
                              depth,
                              control,
                              op_type,
                              do_outputs=False):
    """Helper function used by list_inputs and list_outputs.

    Format a list of lines to display the inputs or output recipients of a
    given node.

    Args:
      recursive: Whether the listing is to be done recursively, as a boolean.
      node_name: The name of the node in question, as a str.
      depth: Maximum recursion depth, applies only if recursive == True, as an
        int.
      control: Whether control inputs or control recipients are included, as a
        boolean.
      op_type: Whether the op types of the nodes are to be included, as a
        boolean.
      do_outputs: Whether recipients, instead of input nodes are to be
        listed, as a boolean.

    Returns:
      Input or recipient tree formatted as a RichTextLines object.
    """
    # Direction-dependent pieces: neighbor-tracking callable and label text.
    if do_outputs:
      tracker = self._debug_dump.node_recipients
      type_str = "Recipients of"
      short_type_str = "recipients"
    else:
      tracker = self._debug_dump.node_inputs
      type_str = "Inputs to"
      short_type_str = "inputs"

    # The argument may be a tensor name; reduce it to the node name.
    node_name, _ = self._parse_node_or_tensor_name(node_name)

    if not self._debug_dump.node_exists(node_name):
      return self._error(
          "There is no node named \"%s\" in the partition graphs" % node_name)

    max_depth = depth if recursive else 1
    include_ctrls_str = (", control %s included" % short_type_str
                         if control else "")

    lines = ["%s node \"%s\" (Depth limit = %d%s):" %
             (type_str, node_name, max_depth, include_ctrls_str)]

    # Depth-first traversal appends the tree body directly into lines.
    self._dfs_from_node(lines, node_name, tracker, max_depth, 1, [], control,
                        op_type)

    # Trailing legend explaining the per-line markers.
    lines.extend(["", "Legend:", "  (d): recursion depth = d."])
    if control:
      lines.append("  (Ctrl): Control input.")
    if op_type:
      lines.append("  [Op]: Input node has op type Op.")

    return debugger_cli_common.RichTextLines(lines)
    def testWrappingEmptyInput(self):
        """Wrapping an empty RichTextLines yields empty lines and indices."""
        wrapped, index_map = debugger_cli_common.wrap_rich_text_lines(
            debugger_cli_common.RichTextLines([]), 10)

        self.assertEqual([], wrapped.lines)
        self.assertEqual([], index_map)
示例#4
0
 def _error(self, msg):
   """Construct a RichTextLines object for an error message.

   Args:
     msg: (str) The error message.

   Returns:
     (RichTextLines) A single line reading "ERROR: " + msg, highlighted in
     red over its full length.
   """
   full_msg = "ERROR: " + msg
   # Highlight the whole rendered line. The previous segment length was
   # len(msg), which ignored the "ERROR: " prefix and therefore left the
   # last 7 characters of the message un-highlighted.
   return debugger_cli_common.RichTextLines(
       [full_msg], font_attr_segs={0: [(0, len(full_msg), "red")]})
 def setUp(self):
     """Create a pristine two-line RichTextLines fixture for each test."""
     self._orig_screen_output = debugger_cli_common.RichTextLines(
         ["Roses are red", "Violets are blue"])
 def testRichTextLinesConstructorWithInvalidType(self):
     """Constructing RichTextLines from a non-str/non-list raises ValueError."""
     # NOTE(review): assertRaisesRegexp is a deprecated alias (removed in
     # Python 3.12); assertRaisesRegex is the modern spelling. Left as-is
     # here because the alias is required on Python 2 runtimes.
     with self.assertRaisesRegexp(ValueError, "Unexpected type in lines"):
         debugger_cli_common.RichTextLines(123)
 def _noop_handler(self, argv, screen_info=None):
     """Test-fixture command handler; ignores its arguments entirely."""
     # A handler that does nothing other than returning "Done."
     return debugger_cli_common.RichTextLines(["Done."])
 def _echo_screen_cols(self, argv, screen_info=None):
     """Test-fixture handler that echoes the "cols" entry of screen_info.

     NOTE(review): despite the None default, screen_info must be a dict with
     a numeric "cols" key -- passing None would raise a TypeError on the
     subscript below.
     """
     # A handler that uses screen_info.
     return debugger_cli_common.RichTextLines(
         ["cols = %d" % screen_info["cols"]])
示例#9
0
def get_error_intro(tf_error):
    """Generate formatted intro for TensorFlow run-time error.

    Args:
      tf_error: (errors.OpError) TensorFlow run-time error object.

    Returns:
      (RichTextLines) Formatted intro message about the run-time OpError, with
        sample commands for debugging.
    """
    # Resolve the name of the op that raised, if the error carries one.
    op_name = (tf_error.op.name
               if hasattr(tf_error, "op") and hasattr(tf_error.op, "name")
               else None)

    out = debugger_cli_common.rich_text_lines_from_rich_line_list([
        "--------------------------------------",
        RL("!!! An error occurred during the run !!!", "blink"),
        "",
    ])

    if op_name is None:
        out.extend(
            debugger_cli_common.RichTextLines([
                "WARNING: Cannot determine the name of the op that caused the error."
            ]))
    else:
        out.extend(
            debugger_cli_common.RichTextLines(
                ["You may use the following commands to debug:"]))
        # One clickable suggestion per debugging command.
        for command, description in (
                ("ni -a -d -t %s" % op_name,
                 "Inspect information about the failing op."),
                ("li -r %s" % op_name,
                 "List inputs to the failing op, recursively."),
                ("lt",
                 "List all tensors dumped during the failing run() call.")):
            out.extend(
                _recommend_command(command, description, create_link=True))

    # Trailing details: op name, error type and the error's own message.
    out.extend(
        debugger_cli_common.RichTextLines([
            "",
            "Op name:    %s" % op_name,
            "Error type: " + str(type(tf_error)),
            "",
            "Details:",
            str(tf_error),
            "",
            "--------------------------------------",
            "",
        ]))

    return out
示例#10
0
def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters):
  """Generate formatted intro for run-start UI.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
      callable.

  Returns:
    (RichTextLines) Formatted intro message about the `Session.run()` call.
  """

  fetch_lines = _get_fetch_names(fetches)

  if not feed_dict:
    feed_dict_lines = ["(Empty)"]
  else:
    feed_dict_lines = []
    for feed_key in feed_dict:
      if isinstance(feed_key, six.string_types):
        feed_dict_lines.append(feed_key)
      else:
        # Non-string feed keys (e.g., Tensors) are displayed by their name.
        feed_dict_lines.append(feed_key.name)

  intro_lines = [
      "======================================",
      "Session.run() call #%d:" % run_call_count,
      "", "Fetch(es):"
  ]
  intro_lines.extend(["  " + line for line in fetch_lines])
  intro_lines.extend(["", "Feed dict(s):"])
  intro_lines.extend(["  " + line for line in feed_dict_lines])
  intro_lines.extend([
      "======================================", "",
      "Select one of the following commands to proceed ---->"
  ])

  out = debugger_cli_common.RichTextLines(intro_lines)

  out.extend(
      _recommend_command(
          "run",
          "Execute the run() call with debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -n",
          "Execute the run() call without debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -t <T>",
          "Execute run() calls (T - 1) times without debugging, then "
          "execute run() one more time and drop back to the CLI"))
  out.extend(
      _recommend_command(
          "run -f <filter_name>",
          "Keep executing run() calls until a dumped tensor passes a given, "
          "registered filter (conditional breakpoint mode)"))

  more_font_attr_segs = {}
  more_lines = ["    Registered filter(s):"]
  if tensor_filters:
    for filter_name in tensor_filters:
      more_lines.append("        * " + filter_name)
      command_menu_node = debugger_cli_common.MenuItem(
          "", "run -f %s" % filter_name)
      # Make the filter name clickable; column 10 is where the name starts,
      # right after the "        * " prefix.
      more_font_attr_segs[len(more_lines) - 1] = [
          (10, len(more_lines[-1]), command_menu_node)]
  else:
    more_lines.append("        (None)")

  out.extend(
      debugger_cli_common.RichTextLines(
          more_lines, font_attr_segs=more_font_attr_segs))

  out.extend(
      _recommend_command(
          "invoke_stepper",
          "Use the node-stepper interface, which allows you to interactively "
          "step through nodes involved in the graph run() call and "
          "inspect/modify their values", create_link=True))

  # Bug fix: a missing comma previously turned the next two entries into an
  # implicit string concatenation ("...help below:" + ""), silently dropping
  # the intended trailing blank line.
  out.extend(debugger_cli_common.RichTextLines([
      "",
      "For more details, see help below:",
      "",
  ]))

  # Make main menu for the run-start intro.
  menu = debugger_cli_common.Menu()
  menu.append(debugger_cli_common.MenuItem("run", "run"))
  menu.append(debugger_cli_common.MenuItem(
      "invoke_stepper", "invoke_stepper"))
  menu.append(debugger_cli_common.MenuItem("exit", "exit"))
  out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

  return out
示例#11
0
def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
    """Generate formatted intro for run-start UI.

    Args:
      run_call_count: (int) Run call counter.
      fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
        for more details.
      feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
        for more details.
      tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
        callable.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.

    Returns:
      (RichTextLines) Formatted intro message about the `Session.run()` call.
    """
    fetch_lines = common.get_flattened_names(fetches)

    if not feed_dict:
        feed_dict_lines = [debugger_cli_common.RichLine("  (Empty)")]
    else:
        feed_dict_lines = []
        for feed_key in feed_dict:
            feed_key_name = common.get_graph_element_name(feed_key)
            feed_dict_line = debugger_cli_common.RichLine("  ")
            # Surround the name string with quotes, because feed_key_name may
            # contain spaces in some cases, e.g., SparseTensors.
            feed_dict_line += debugger_cli_common.RichLine(
                feed_key_name,
                debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
            feed_dict_lines.append(feed_dict_line)
    feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
        feed_dict_lines)

    out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
    if is_callable_runner:
        out.append("Running a runner returned by Session.make_callable()")
    else:
        out.append("Session.run() call #%d:" % run_call_count)
        out.append("")
        out.append("Fetch(es):")
        out.extend(
            debugger_cli_common.RichTextLines(
                ["  " + line for line in fetch_lines]))
        out.append("")
        out.append("Feed dict:")
        out.extend(feed_dict_lines)
    out.append(_HORIZONTAL_BAR)
    out.append("")
    out.append("Select one of the following commands to proceed ---->")

    out.extend(
        _recommend_command("run",
                           "Execute the run() call with debug tensor-watching",
                           create_link=True))
    out.extend(
        _recommend_command(
            "run -n",
            "Execute the run() call without debug tensor-watching",
            create_link=True))
    out.extend(
        _recommend_command(
            "run -t <T>",
            "Execute run() calls (T - 1) times without debugging, then "
            "execute run() once more with debugging and drop back to the CLI"))
    out.extend(
        _recommend_command(
            "run -f <filter_name>",
            "Keep executing run() calls until a dumped tensor passes a given, "
            "registered filter (conditional breakpoint mode)"))

    more_lines = ["    Registered filter(s):"]
    if tensor_filters:
        # Each registered filter gets a clickable "run -f <name>" menu item.
        # (An unused filter_names accumulator was removed here.)
        for filter_name in tensor_filters:
            command_menu_node = debugger_cli_common.MenuItem(
                "", "run -f %s" % filter_name)
            more_lines.append(
                RL("        * ") + RL(filter_name, command_menu_node))
    else:
        more_lines.append("        (None)")

    out.extend(
        debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

    out.extend(
        _recommend_command(
            "invoke_stepper",
            "Use the node-stepper interface, which allows you to interactively "
            "step through nodes involved in the graph run() call and "
            "inspect/modify their values",
            create_link=True))

    out.append("")

    # Bug fix: the link text used to be "help." followed by a literal ".",
    # rendering a doubled period ("...see help.."). Link text is now "help".
    out.append_rich_line(
        RL("For more details, see ") +
        RL("help", debugger_cli_common.MenuItem("", "help")) + ".")
    out.append("")

    # Make main menu for the run-start intro.
    menu = debugger_cli_common.Menu()
    menu.append(debugger_cli_common.MenuItem("run", "run"))
    menu.append(
        debugger_cli_common.MenuItem("invoke_stepper", "invoke_stepper"))
    menu.append(debugger_cli_common.MenuItem("exit", "exit"))
    out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

    return out
示例#12
0
    def list_sorted_nodes(self, args, screen_info=None):
        """List the sorted transitive closure of the stepper's fetches.

        Args:
          args: Command-line arguments, excluding the command prefix, as a
            list of str.
          screen_info: Optional dict input containing screen information such
            as cols. Unused by this handler.

        Returns:
          Output text lines as a RichTextLines object.
        """

        # TODO(cais): Use pattern such as del args, del screen_info python/debug.
        _ = args
        _ = screen_info
        # NOTE(review): args is assigned to "_" above yet still parsed below,
        # so the "_ = args" line appears vestigial -- confirm before removing.

        parsed = self.arg_parsers["list_sorted_nodes"].parse_args(args)

        if parsed.lower_bound != -1 and parsed.upper_bound != -1:
            # Both bounds supplied: clamp to the valid index range and emit
            # the compact (non-verbose) listing without header or legend.
            index_range = [
                max(0, parsed.lower_bound),
                min(len(self._sorted_nodes), parsed.upper_bound)
            ]
            verbose = False
        else:
            # Default: list every sorted node, with header and legend.
            index_range = [0, len(self._sorted_nodes)]
            verbose = True

        handle_node_names = self._node_stepper.handle_node_names()
        intermediate_tensor_names = self._node_stepper.intermediate_tensor_names(
        )
        override_names = self._node_stepper.override_names()
        # Dirty variables are reported as "name:0"-style tensor names; keep
        # only the node-name part for matching.
        dirty_variable_names = [
            dirty_variable.split(":")[0]
            for dirty_variable in self._node_stepper.dirty_variables()
        ]

        lines = []
        font_attr_segs = {}
        if verbose:
            lines.extend([
                "Topologically-sorted transitive input(s) and fetch(es):", ""
            ])

        # line_counter tracks the output row index so that font attribute
        # segments land on the correct rendered line.
        line_counter = len(lines)
        for i, element_name in enumerate(self._sorted_nodes):
            if i < index_range[0] or i >= index_range[1]:
                continue

            font_attr_segs[line_counter] = []

            # TODO(cais): Use fixed-width text to show node index.
            node_prefix = "(%d / %d)" % (i + 1, len(self._sorted_nodes))
            if i == self._next:
                # Mark the next node to be executed with a bold pointer.
                node_prefix = "  " + self.NEXT_NODE_POINTER_STR + node_prefix
                font_attr_segs[line_counter].append((0, 3, "bold"))
            else:
                node_prefix = "     " + node_prefix

            node_prefix += "  ["
            # Status labels (e.g., handle/intermediate/override/dirty) plus
            # their font attributes, offset by the current prefix length.
            labels, label_font_attr_segs = self._get_status_labels(
                element_name, handle_node_names, intermediate_tensor_names,
                override_names, dirty_variable_names, len(node_prefix))
            node_prefix += labels
            font_attr_segs[line_counter].extend(label_font_attr_segs)

            lines.append(node_prefix + "] " + element_name)
            line_counter += 1

        output = debugger_cli_common.RichTextLines(
            lines, font_attr_segs=font_attr_segs)

        if verbose:
            output.extend(self._node_status_label_legend())

        return output
示例#13
0
  def print_source(self, args, screen_info=None):
    """Print a Python source file with line-level profile information.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """
    del screen_info

    parsed = self._arg_parsers["print_source"].parse_args(args)

    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    # Collect profile data from every device that passes the device-name
    # filter.
    profile_data = []
    data_generator = self._get_profile_data_generator()
    device_count = len(self._run_metadata.step_stats.dev_stats)
    for index in range(device_count):
      device_stats = self._run_metadata.step_stats.dev_stats[index]
      if device_name_regex and not device_name_regex.match(device_stats.device):
        continue
      profile_data.extend([datum for datum in data_generator(device_stats)])

    # Map source line numbers to profile annotations, subject to the node-name
    # and op-type filters.
    source_annotation = source_utils.annotate_source_against_profile(
        profile_data,
        os.path.expanduser(parsed.source_file_path),
        node_name_filter=parsed.node_name_filter,
        op_type_filter=parsed.op_type_filter)
    if not source_annotation:
      return debugger_cli_common.RichTextLines(
          ["The source file %s does not contain any profile information for "
           "the previous Session run under the following "
           "filters:" % parsed.source_file_path,
           "  --%s: %s" % (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
           "  --%s: %s" % (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
           "  --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)])

    # The most expensive line sets the scale for the normalized cost bars.
    max_total_cost = 0
    for line_index in source_annotation:
      total_cost = self._get_total_cost(source_annotation[line_index],
                                        parsed.cost_type)
      max_total_cost = max(max_total_cost, total_cost)

    source_lines, line_num_width = source_utils.load_source(
        parsed.source_file_path)

    # Fixed column layout: cost bar | total cost | node/exec counts | line
    # number | source text.
    cost_bar_max_length = 10
    total_cost_head = parsed.cost_type
    column_widths = {
        "cost_bar": cost_bar_max_length + 3,
        "total_cost": len(total_cost_head) + 3,
        "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
        "line_number": line_num_width,
    }

    head = RL(
        " " * column_widths["cost_bar"] +
        total_cost_head +
        " " * (column_widths["total_cost"] - len(total_cost_head)) +
        self._NUM_NODES_HEAD +
        " " * (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
        font_attr=self._LINE_COST_ATTR)
    head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
    sub_head = RL(
        " " * (column_widths["cost_bar"] +
               column_widths["total_cost"]) +
        self._NUM_EXECS_SUB_HEAD +
        " " * (column_widths["num_nodes_execs"] -
               len(self._NUM_EXECS_SUB_HEAD)) +
        " " * column_widths["line_number"],
        font_attr=self._LINE_COST_ATTR)
    sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
    lines = [head, sub_head]

    output_annotations = {}
    for i, line in enumerate(source_lines):
      lineno = i + 1
      if lineno in source_annotation:
        # Annotated line: render cost bar, total cost and node/exec counts.
        annotation = source_annotation[lineno]
        cost_bar = self._render_normalized_cost_bar(
            self._get_total_cost(annotation, parsed.cost_type), max_total_cost,
            cost_bar_max_length)
        annotated_line = cost_bar
        annotated_line += " " * (column_widths["cost_bar"] - len(cost_bar))

        total_cost = RL(cli_shared.time_to_readable_str(
            self._get_total_cost(annotation, parsed.cost_type),
            force_time_unit=parsed.time_unit),
                        font_attr=self._LINE_COST_ATTR)
        total_cost += " " * (column_widths["total_cost"] - len(total_cost))
        annotated_line += total_cost

        # Build an "lp" command scoped to exactly this line, carrying over any
        # active filters, and attach it as a clickable menu item.
        file_path_filter = re.escape(parsed.source_file_path) + "$"
        command = "lp --file_path_filter %s --min_lineno %d --max_lineno %d" % (
            file_path_filter, lineno, lineno + 1)
        if parsed.device_name_filter:
          command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                   parsed.device_name_filter)
        if parsed.node_name_filter:
          command += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                   parsed.node_name_filter)
        if parsed.op_type_filter:
          command += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                   parsed.op_type_filter)
        menu_item = debugger_cli_common.MenuItem(None, command)
        num_nodes_execs = RL("%d(%d)" % (annotation.node_count,
                                         annotation.node_exec_count),
                             font_attr=[self._LINE_COST_ATTR, menu_item])
        num_nodes_execs += " " * (
            column_widths["num_nodes_execs"] - len(num_nodes_execs))
        annotated_line += num_nodes_execs
      else:
        # Un-annotated line: blank padding spanning all cost columns.
        annotated_line = RL(
            " " * sum(column_widths[col_name] for col_name in column_widths
                      if col_name != "line_number"))

      line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
      line_num_column += " " * (
          column_widths["line_number"] - len(line_num_column))
      annotated_line += line_num_column
      annotated_line += line
      lines.append(annotated_line)

      # Remember where to scroll to if the user asked to start at this line.
      if parsed.init_line == lineno:
        output_annotations[
            debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        lines, annotations=output_annotations)
示例#14
0
def format_tensor(tensor,
                  tensor_label,
                  include_metadata=False,
                  auxiliary_message=None,
                  include_numeric_summary=False,
                  np_printoptions=None,
                  highlight_options=None):
    """Generate a RichTextLines object showing a tensor in formatted style.

    Args:
      tensor: The tensor to be displayed, as a numpy ndarray or other
        appropriate format (e.g., None representing uninitialized tensors).
      tensor_label: A label for the tensor, as a string. If set to None, will
        suppress the tensor name line in the return value.
      include_metadata: Whether metadata such as dtype and shape are to be
        included in the formatted text.
      auxiliary_message: An auxiliary message to display under the tensor
        label, dtype and shape information lines.
      include_numeric_summary: Whether a text summary of the numeric values
        (if applicable) will be included.
      np_printoptions: A dictionary of keyword arguments that are passed to a
        call of np.set_printoptions() to set the text format for display numpy
        ndarrays.
      highlight_options: (HighlightOptions) options for highlighting elements
        of the tensor.

    Returns:
      A RichTextLines object. Its annotation field has line-by-line markups to
      indicate which indices in the array the first element of each line
      corresponds to.
    """
    lines = []
    font_attr_segs = {}

    if tensor_label is not None:
        lines.append("Tensor \"%s\":" % tensor_label)
        suffix = tensor_label.split(":")[-1]
        if suffix.isdigit():
            # Suffix is a number. Assume it is the output slot index.
            font_attr_segs[0] = [(8, 8 + len(tensor_label), "bold")]
        else:
            # Suffix is not a number. It is auxiliary information such as the
            # debug op type. In this case, highlight the suffix with a
            # different color.
            debug_op_len = len(suffix)
            proper_len = len(tensor_label) - debug_op_len - 1
            font_attr_segs[0] = [(8, 8 + proper_len, "bold"),
                                 (8 + proper_len + 1,
                                  8 + proper_len + 1 + debug_op_len, "yellow")]

    if isinstance(tensor, debug_data.InconvertibleTensorProto):
        if lines:
            lines.append("")
        lines.extend(str(tensor).split("\n"))
        return debugger_cli_common.RichTextLines(lines)
    elif not isinstance(tensor, np.ndarray):
        # If tensor is not a np.ndarray, return simple text-line
        # representation of the object without annotations.
        if lines:
            lines.append("")
        lines.extend(repr(tensor).split("\n"))
        return debugger_cli_common.RichTextLines(lines)

    if include_metadata:
        lines.append("  dtype: %s" % str(tensor.dtype))
        lines.append("  shape: %s" % str(tensor.shape))

    if lines:
        lines.append("")
    formatted = debugger_cli_common.RichTextLines(
        lines, font_attr_segs=font_attr_segs)

    if auxiliary_message:
        formatted.extend(auxiliary_message)

    if include_numeric_summary:
        formatted.append("Numeric summary:")
        formatted.extend(numeric_summary(tensor))
        formatted.append("")

    # Apply custom string formatting options for numpy ndarray.
    if np_printoptions is not None:
        np.set_printoptions(**np_printoptions)

    array_lines = repr(tensor).split("\n")
    # Bug fix: np.string_ was a deprecated alias removed in NumPy 2.0;
    # np.bytes_ is the same type object and works on both NumPy 1.x and 2.x.
    if tensor.dtype.type is not np.bytes_:
        # Parse array lines to get beginning indices for each line.

        # TODO (cais): Currently, we do not annotate string-type tensors due to id:2690 gh:2691
        #   difficulty in escaping sequences. Address this issue.
        annotations = _annotate_ndarray_lines(array_lines,
                                              tensor,
                                              np_printoptions=np_printoptions)
    else:
        annotations = None
    formatted_array = debugger_cli_common.RichTextLines(
        array_lines, annotations=annotations)
    formatted.extend(formatted_array)

    # Perform optional highlighting.
    if highlight_options is not None:
        indices_list = list(np.argwhere(highlight_options.criterion(tensor)))

        total_elements = np.size(tensor)
        # Guard against ZeroDivisionError for zero-element arrays.
        highlight_percentage = (
            len(indices_list) / float(total_elements) * 100.0
            if total_elements else 0.0)
        highlight_summary = "Highlighted%s: %d of %d element(s) (%.2f%%)" % (
            "(%s)" % highlight_options.description
            if highlight_options.description else "", len(indices_list),
            total_elements, highlight_percentage)

        formatted.lines[0] += " " + highlight_summary

        if indices_list:
            indices_list = [list(indices) for indices in indices_list]

            are_omitted, rows, start_cols, end_cols = locate_tensor_element(
                formatted, indices_list)
            for is_omitted, row, start_col, end_col in zip(
                    are_omitted, rows, start_cols, end_cols):
                if is_omitted or start_col is None or end_col is None:
                    continue

                # Append to any existing segments on the row rather than
                # overwriting them.
                if row in formatted.font_attr_segs:
                    formatted.font_attr_segs[row].append(
                        (start_col, end_col, highlight_options.font_attr))
                else:
                    formatted.font_attr_segs[row] = [
                        (start_col, end_col, highlight_options.font_attr)
                    ]

    return formatted
示例#15
0
def numeric_summary(tensor):
    """Get a text summary of a numeric tensor.

    This summary is only available for numeric (int*, float*, complex*) and
    Boolean tensors.

    Args:
      tensor: (`numpy.ndarray`) the tensor value object to be summarized.

    Returns:
      The summary text as a `RichTextLines` object. If the type of `tensor` is
      not numeric or Boolean, a single-line `RichTextLines` object containing
      a warning message will reflect that.
    """
    def _counts_summary(counts, skip_zeros=True, total_count=None):
        """Format (key, count) pairs as a two-row table of aligned cells."""
        if skip_zeros:
            counts = [(count_key, count_val) for count_key, count_val in counts
                      if count_val]
        # All cells share the width of the widest key/value, plus one space.
        max_common_len = 0
        for count_key, count_val in counts:
            count_val_str = str(count_val)
            common_len = max(len(count_key) + 1, len(count_val_str) + 1)
            max_common_len = max(common_len, max_common_len)

        key_line = debugger_cli_common.RichLine("|")
        val_line = debugger_cli_common.RichLine("|")
        for count_key, count_val in counts:
            count_val_str = str(count_val)
            key_line += _pad_string_to_length(count_key, max_common_len)
            val_line += _pad_string_to_length(count_val_str, max_common_len)
        key_line += " |"
        val_line += " |"

        if total_count is not None:
            # Optional trailing "total" cell.
            total_key_str = "total"
            total_val_str = str(total_count)
            max_common_len = max(len(total_key_str) + 1, len(total_val_str))
            total_key_str = _pad_string_to_length(total_key_str,
                                                  max_common_len)
            total_val_str = _pad_string_to_length(total_val_str,
                                                  max_common_len)
            key_line += total_key_str + " |"
            val_line += total_val_str + " |"

        return debugger_cli_common.rich_text_lines_from_rich_line_list(
            [key_line, val_line])

    if not isinstance(tensor, np.ndarray) or not np.size(tensor):
        return debugger_cli_common.RichTextLines(
            ["No numeric summary available due to empty tensor."])
    # Bug fix: np.float, np.complex and np.bool were deprecated builtin
    # aliases removed in NumPy >= 1.24; use the exact abstract-dtype
    # equivalents np.floating / np.complexfloating and the scalar type
    # np.bool_, which behave identically under issubdtype / dtype ==.
    elif (np.issubdtype(tensor.dtype, np.floating)
          or np.issubdtype(tensor.dtype, np.complexfloating)
          or np.issubdtype(tensor.dtype, np.integer)):
        # Bucket the elements: NaN, +/-inf, negative, zero, positive.
        counts = [("nan", np.sum(np.isnan(tensor))),
                  ("-inf", np.sum(np.isneginf(tensor))),
                  ("-",
                   np.sum(
                       np.logical_and(tensor < 0.0,
                                      np.logical_not(np.isneginf(tensor))))),
                  ("0", np.sum(tensor == 0.0)),
                  ("+",
                   np.sum(
                       np.logical_and(tensor > 0.0,
                                      np.logical_not(np.isposinf(tensor))))),
                  ("+inf", np.sum(np.isposinf(tensor)))]
        output = _counts_summary(counts, total_count=np.size(tensor))

        # Descriptive statistics over finite, non-NaN entries only.
        valid_array = tensor[np.logical_not(
            np.logical_or(np.isinf(tensor), np.isnan(tensor)))]
        if np.size(valid_array):
            stats = [("min", np.min(valid_array)),
                     ("max", np.max(valid_array)),
                     ("mean", np.mean(valid_array)),
                     ("std", np.std(valid_array))]
            output.extend(_counts_summary(stats, skip_zeros=False))
        return output
    elif tensor.dtype == np.bool_:
        counts = [
            ("False", np.sum(tensor == 0)),
            ("True", np.sum(tensor > 0)),
        ]
        return _counts_summary(counts, total_count=np.size(tensor))
    else:
        return debugger_cli_common.RichTextLines([
            "No numeric summary available due to tensor dtype: %s." %
            tensor.dtype
        ])