예제 #1
0
def _recommend_command(command, description, indent=2, create_link=False):
    """Build the two-line rich-text entry that advertises one command.

    The first line shows the command in bold, followed by a colon; the second
    line shows the description, indented three spaces further than the
    command.

    Args:
      command: (str) The command to recommend.
      description: (str) A description of what the command does.
      indent: (int) Number of leading spaces placed before the command.
      create_link: (bool) Whether the command text should also carry a
        MenuItem that issues `command`.

    Returns:
      (RichTextLines) Formatted text (with font attributes) for recommending
        the command.
    """
    padding = " " * indent

    # A linked command carries a MenuItem in addition to the bold attribute.
    command_attr = ([debugger_cli_common.MenuItem("", command), "bold"]
                    if create_link else "bold")

    header_line = RL(padding) + RL(command, command_attr) + ":"
    description_line = padding + "   " + description

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        [header_line, description_line])
예제 #2
0
    def render(self,
               max_length,
               backward_command,
               forward_command,
               latest_command_attribute="black_on_white",
               old_command_attribute="magenta_on_white"):
        """Render the rich text content of the single-line navigation bar.

        The bar starts with the back and forward arrows (each one carries a
        menu item only when navigation in that direction is possible). If
        there are history items, it continues with an optional "(-N) " marker
        showing how far the pointer is behind the latest item, then the
        command at the pointer, truncated to fit `max_length`.

        Args:
          max_length: (`int`) Maximum length of the navigation bar, in
            characters.
          backward_command: (`str`) command for going backward. Used to
            construct the shortcut menu item.
          forward_command: (`str`) command for going forward. Used to
            construct the shortcut menu item.
          latest_command_attribute: font attribute for the latest command.
          old_command_attribute: font attribute for an old (non-latest)
            command.

        Returns:
          (`debugger_cli_common.RichTextLines`) the navigation bar text with
            attributes.
        """
        bar = RL("| ")

        back_link = (debugger_cli_common.MenuItem(None, backward_command)
                     if self.can_go_back() else None)
        bar += RL(self.BACK_ARROW_TEXT, back_link)
        bar += RL(" ")

        forward_link = (debugger_cli_common.MenuItem(None, forward_command)
                        if self.can_go_forward() else None)
        bar += RL(self.FORWARD_ARROW_TEXT, forward_link)

        if not self._items:
            # No history yet: the bar consists of the two arrows only.
            return debugger_cli_common.rich_text_lines_from_rich_line_list(
                [bar])

        last_index = len(self._items) - 1
        at_latest = self._pointer == last_index
        attribute = (latest_command_attribute
                     if at_latest else old_command_attribute)

        bar += RL(" | ")
        if not at_latest:
            # Show how many steps the pointer is behind the latest item.
            bar += RL("(-%d) " % (last_index - self._pointer), attribute)

        # Append the command only if there is room left, cutting it to fit.
        remaining = max_length - len(bar)
        if remaining > 0:
            bar += RL(self._items[self._pointer].command[:remaining],
                      attribute)

        return debugger_cli_common.rich_text_lines_from_rich_line_list([bar])
예제 #3
0
    def value(self,
              row,
              col,
              device_name_filter=None,
              node_name_filter=None,
              op_type_filter=None):
        """Return the rich-text content of one cell of the table.

        Columns 0-4 are plain text: node name, op type, formatted start time,
        formatted op time and formatted exec time. Column 5 shows the datum's
        `file_line_func` text and carries a menu item whose `ps` command
        forwards the given filters and opens the datum's file at its line.

        Args:
          row: (int) row index.
          col: (int) column index.
          device_name_filter: Regular expression to filter by device name.
          node_name_filter: Regular expression to filter by node name.
          op_type_filter: Regular expression to filter by op type.

        Returns:
          A debugger_cli_common.RichLine object representing the content of
          the cell, potentially with a clickable MenuItem.

        Raises:
          IndexError: if the column index is not one of 0-5, or if the row
            index is out of range.
        """
        if col == 0:
            return RL(
                self._profile_datum_list[row].node_exec_stats.node_name,
                font_attr=None)
        if col == 1:
            return RL(self._profile_datum_list[row].op_type, font_attr=None)
        if col == 2:
            return RL(str(self.formatted_start_time[row]), font_attr=None)
        if col == 3:
            return RL(str(self.formatted_op_time[row]), font_attr=None)
        if col == 4:
            return RL(str(self.formatted_exec_time[row]), font_attr=None)
        if col != 5:
            raise IndexError("Invalid column index %d." % col)

        # Column 5: source location, linked to a command that prints the
        # source file with the currently active filters carried over.
        datum = self._profile_datum_list[row]
        active_filters = (
            (_DEVICE_NAME_FILTER_FLAG, device_name_filter),
            (_NODE_NAME_FILTER_FLAG, node_name_filter),
            (_OP_TYPE_FILTER_FLAG, op_type_filter),
        )
        command = "ps"
        for flag, pattern in active_filters:
            if pattern:
                command += " --%s %s" % (flag, pattern)
        command += " %s --init_line %d" % (datum.file_path,
                                           datum.line_number)

        link = debugger_cli_common.MenuItem(None, command)
        return RL(datum.file_line_func, font_attr=link)
예제 #4
0
    def _run_info_handler(self, args, screen_info=None):
        """Assemble the run-info screen.

        The output is the stored run info, preceded by the logo when the run
        call count is 1. Unless this is a run-start screen, a "list_tensors"
        entry is put at the front of the main menu (if the output has a main
        menu and the entry is not already there).

        Args:
          args: Command arguments. Currently unused.
          screen_info: Screen information. Currently unused.

        Returns:
          The assembled debugger_cli_common.RichTextLines object.
        """
        del args, screen_info  # Currently unused.

        output = debugger_cli_common.RichTextLines([])
        if self._run_call_count == 1:
            output.extend(cli_shared.get_tvmdbg_logo())
        output.extend(self._run_info)

        # Nothing more to do on a run-start screen or without a main menu.
        if (self._is_run_start or
                debugger_cli_common.MAIN_MENU_KEY not in output.annotations):
            return output

        main_menu = output.annotations[debugger_cli_common.MAIN_MENU_KEY]
        if "list_tensors" not in main_menu.captions():
            list_tensors_item = debugger_cli_common.MenuItem(
                "list_tensors", "list_tensors")
            main_menu.insert(0, list_tensors_item)

        return output
예제 #5
0
def _get_list_profile_lines(device_name,
                            device_index,
                            device_count,
                            profile_datum_list,
                            sort_by,
                            sort_reverse,
                            time_unit,
                            device_name_filter=None,
                            node_name_filter=None,
                            op_type_filter=None,
                            screen_cols=80):
    """Get `RichTextLines` object for list_profile command for a given device.

    The output consists of a horizontal rule, a "Device i of n" line, a header
    row whose column names are clickable sort commands, one row per profile
    datum, and a final "Device Total" row.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index.
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Identifier of column to sort. Sort identifier
          must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
          SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_MEMORY or SORT_OPS_BY_LINE.
      sort_reverse: (bool) Whether to sort in descending instead of default
          (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.
      screen_cols: (int) Number of columns available on the screen (i.e.,
        available screen width).

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op.
    """
    profile_data = ProfileDataTableView(profile_datum_list,
                                        time_unit=time_unit)
    column_names = profile_data.column_names()
    row_count = profile_data.row_count()
    cell_filters = {
        "device_name_filter": device_name_filter,
        "node_name_filter": node_name_filter,
        "op_type_filter": op_type_filter,
    }

    # Calculate total time early to calculate column widths.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    # NOTE(review): these four cells land in table columns 0-3; confirm that
    # columns 2 and 3 of ProfileDataTableView are the op-time and exec-time
    # columns the totals are meant to sit under.
    device_total_row = [
        "Device Total", "",
        cli_shared.time_to_readable_str(total_op_time,
                                        force_time_unit=time_unit),
        cli_shared.time_to_readable_str(total_exec_time,
                                        force_time_unit=time_unit)
    ]

    # Calculate column widths: each column must fit its header, its totals
    # cell (if any) and its widest data cell.
    column_widths = [len(column_name) for column_name in column_names]
    for i, total_cell in enumerate(device_total_row):
        column_widths[i] = max(column_widths[i], len(total_cell))
    for i in range(len(column_widths)):
        for row_index in range(row_count):
            cell = profile_data.value(row_index, i, **cell_filters)
            # Compare against the running maximum, not the initial width;
            # otherwise only the last row would decide the column width and
            # longer cells in earlier rows would get negative padding.
            column_widths[i] = max(column_widths[i], len(cell))
        column_widths[i] += 2  # add margin between columns

    # Add device name.
    output = [RL("-" * screen_cols)]
    device_row = "Device %d of %d: %s" % (device_index + 1, device_count,
                                          device_name)
    output.append(RL(device_row))
    output.append(RL())

    # Add headers. Clicking a header sorts by that column; clicking the
    # column that is already sorted in ascending order reverses the sort.
    base_command = "list_profile"
    header_row = RL()
    for col in range(profile_data.column_count()):
        column_name = column_names[col]
        sort_id = profile_data.column_sort_id(col)
        command = "%s -s %s" % (base_command, sort_id)
        if sort_by == sort_id and not sort_reverse:
            command += " -r"
        head_menu_item = debugger_cli_common.MenuItem(None, command)
        header_row += RL(column_name, font_attr=[head_menu_item, "bold"])
        header_row += RL(" " * (column_widths[col] - len(column_name)))

    output.append(header_row)

    # Add data rows.
    for row_index in range(row_count):
        data_row = RL()
        for col in range(profile_data.column_count()):
            new_cell = profile_data.value(row_index, col, **cell_filters)
            data_row += new_cell
            data_row += RL(" " * (column_widths[col] - len(new_cell)))
        output.append(data_row)

    # Add stat totals, left-aligned to the same column widths.
    totals_str = "".join(
        ("{:<%d}" % column_widths[i]).format(total_cell)
        for i, total_cell in enumerate(device_total_row))
    output.append(RL())
    output.append(RL(totals_str))
    return debugger_cli_common.rich_text_lines_from_rich_line_list(output)
예제 #6
0
    def print_source(self, args, screen_info=None):
        """Print a Python source file with line-level profile information.

        Every source line is prefixed with four columns: a normalized cost
        bar, the total cost, the "nodes(execs)" counts (clickable, issuing an
        `lp` command restricted to that line) and the line number. Lines
        without profile data get blank padding in the first three columns.

        Args:
          args: Command-line arguments, excluding the command prefix, as a list of
            str.
          screen_info: Optional dict input containing screen information such as
            cols.

        Returns:
          Output text lines as a RichTextLines object.
        """
        del screen_info  # Unused.

        parsed = self._arg_parsers["print_source"].parse_args(args)
        cost_type = parsed.cost_type

        # Expand "~" once and use the same path for the annotation lookup,
        # for loading the file and for the drill-down filter, so that all
        # three refer to the same file. expanduser() leaves paths without a
        # leading "~" unchanged.
        source_file_path = os.path.expanduser(parsed.source_file_path)

        device_name_regex = (re.compile(parsed.device_name_filter)
                             if parsed.device_name_filter else None)

        # Collect profile data from every device that passes the name filter.
        profile_data = []
        data_generator = self._get_profile_data_generator()
        for device_stats in self._run_metadata.step_stats.dev_stats:
            if device_name_regex and not device_name_regex.match(
                    device_stats.device):
                continue
            profile_data.extend(data_generator(device_stats))

        source_annotation = source_utils.annotate_source_against_profile(
            profile_data,
            source_file_path,
            node_name_filter=parsed.node_name_filter,
            op_type_filter=parsed.op_type_filter)
        if not source_annotation:
            return debugger_cli_common.RichTextLines([
                "The source file %s does not contain any profile information for "
                "the previous Session run under the following "
                "filters:" % parsed.source_file_path,
                "  --%s: %s" %
                (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
                "  --%s: %s" %
                (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
                "  --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)
            ])

        # The largest per-line cost is the reference for the cost bars.
        max_total_cost = 0
        for line_index in source_annotation:
            max_total_cost = max(
                max_total_cost,
                _get_total_cost(source_annotation[line_index], cost_type))

        source_lines, line_num_width = source_utils.load_source(
            source_file_path)

        cost_bar_max_length = 10
        total_cost_head = cost_type
        column_widths = {
            "cost_bar": cost_bar_max_length + 3,
            "total_cost": len(total_cost_head) + 3,
            "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
            "line_number": line_num_width,
        }
        # Width of the blank padding used for lines without profile data:
        # every column except the line number.
        unannotated_prefix_width = sum(
            width for col_name, width in column_widths.items()
            if col_name != "line_number")

        head = RL(
            " " * column_widths["cost_bar"] + total_cost_head + " " *
            (column_widths["total_cost"] - len(total_cost_head)) +
            self._NUM_NODES_HEAD + " " *
            (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
            font_attr=self._LINE_COST_ATTR)
        head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
        sub_head = RL(
            " " * (column_widths["cost_bar"] + column_widths["total_cost"]) +
            self._NUM_EXECS_SUB_HEAD + " " *
            (column_widths["num_nodes_execs"] - len(self._NUM_EXECS_SUB_HEAD))
            + " " * column_widths["line_number"],
            font_attr=self._LINE_COST_ATTR)
        sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
        lines = [head, sub_head]

        # The parts of the drill-down command that are the same for every
        # line: the file filter and the filters carried over from this call.
        file_path_filter = re.escape(source_file_path) + "$"
        carried_filters = ""
        if parsed.device_name_filter:
            carried_filters += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                             parsed.device_name_filter)
        if parsed.node_name_filter:
            carried_filters += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                             parsed.node_name_filter)
        if parsed.op_type_filter:
            carried_filters += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                             parsed.op_type_filter)

        output_annotations = {}
        for lineno, line in enumerate(source_lines, 1):
            if lineno in source_annotation:
                annotation = source_annotation[lineno]
                line_cost = _get_total_cost(annotation, cost_type)

                cost_bar = self._render_normalized_cost_bar(
                    line_cost, max_total_cost, cost_bar_max_length)
                annotated_line = cost_bar
                annotated_line += " " * (column_widths["cost_bar"] -
                                         len(cost_bar))

                total_cost = RL(cli_shared.time_to_readable_str(
                    line_cost, force_time_unit=parsed.time_unit),
                                font_attr=self._LINE_COST_ATTR)
                total_cost += " " * (column_widths["total_cost"] -
                                     len(total_cost))
                annotated_line += total_cost

                command = (
                    "lp --file_path_filter %s --min_lineno %d --max_lineno %d"
                    % (file_path_filter, lineno, lineno + 1))
                command += carried_filters
                menu_item = debugger_cli_common.MenuItem(None, command)
                num_nodes_execs = RL(
                    "%d(%d)" %
                    (annotation.node_count, annotation.node_exec_count),
                    font_attr=[self._LINE_COST_ATTR, menu_item])
                num_nodes_execs += " " * (column_widths["num_nodes_execs"] -
                                          len(num_nodes_execs))
                annotated_line += num_nodes_execs
            else:
                annotated_line = RL(" " * unannotated_prefix_width)

            line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
            line_num_column += " " * (column_widths["line_number"] -
                                      len(line_num_column))
            annotated_line += line_num_column
            annotated_line += line
            lines.append(annotated_line)

            # Remember which output line the screen should scroll to.
            if parsed.init_line == lineno:
                output_annotations[
                    debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

        return debugger_cli_common.rich_text_lines_from_rich_line_list(
            lines, annotations=output_annotations)
예제 #7
0
def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
    """Generate formatted intro for run-start UI.

    The intro describes the upcoming run (its outputs and inputs, or a
    one-line notice for a callable runner), lists the recommended `run`
    command variants and the registered tensor filters, and attaches a main
    menu with "run" and "exit" entries.

    Args:
      run_call_count: (int) Run call counter.
      fetches: Fetches of the `GraphRuntime.run()` call. See doc of
        `GraphRuntime.run()` for more details.
      feed_dict: Feeds to the `GraphRuntime.run()` call. See doc of
        `GraphRuntime.run()` for more details.
      tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
        callable.
      is_callable_runner: (bool) whether a runner returned by
          GraphRuntime.make_callable is being run.

    Returns:
      (RichTextLines) Formatted intro message about the `GraphRuntime.run()`
        call.
    """
    fetch_lines = common.get_flattened_names(fetches)

    # One clickable line per feed key, or a single "(Empty)" placeholder.
    if feed_dict:
        feed_dict_lines = []
        for feed_key in feed_dict:
            feed_key_name = common.get_graph_element_name(feed_key)
            # Surround the name string with quotes, because feed_key_name may
            # contain spaces in some cases, e.g., SparseTensors.
            feed_dict_lines.append(
                debugger_cli_common.RichLine("  ") +
                debugger_cli_common.RichLine(
                    feed_key_name,
                    debugger_cli_common.MenuItem(
                        None, "pf '%s'" % feed_key_name)))
    else:
        feed_dict_lines = [debugger_cli_common.RichLine("  (Empty)")]
    feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
        feed_dict_lines)

    # Describe the run between two horizontal bars.
    out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
    if is_callable_runner:
        out.append(
            " Running a runner returned by GraphRuntime.make_callable()")
    else:
        out.append(" GraphRuntime.run() call #%d:" % run_call_count)
        out.append("")
        out.append(" Output:")
        indented_fetches = ["   " + line for line in fetch_lines]
        out.extend(debugger_cli_common.RichTextLines(indented_fetches))
        out.append("")
        out.append(" Inputs:")
        out.extend(feed_dict_lines)
    out.append(_HORIZONTAL_BAR)
    out.append("")
    out.append(" Select one of the following commands to proceed ---->")

    # (command, description, clickable) for each recommended way to proceed.
    recommendations = (
        ("run",
         "Execute the run() call with debug tensor-watching",
         True),
        ("run -n",
         "Execute the run() call without debug tensor-watching",
         True),
        ("run -t <T>",
         "Execute run() calls (T - 1) times without debugging, then "
         "execute run() once more with debugging and drop back to the CLI",
         False),
        ("run -f <filter_name>",
         "Keep executing run() calls until a dumped tensor passes a given, "
         "registered filter (conditional breakpoint mode)",
         False),
    )
    for recommended, explanation, clickable in recommendations:
        out.extend(
            _recommend_command(recommended,
                               explanation,
                               create_link=clickable))

    # List the registered filters, each one linked to its "run -f" command.
    more_lines = ["    Registered filter(s):"]
    if tensor_filters:
        more_lines.extend(
            RL("        * ") +
            RL(filter_name,
               debugger_cli_common.MenuItem("", "run -f %s" % filter_name))
            for filter_name in tensor_filters)
    else:
        more_lines.append("        (None)")

    out.extend(
        debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

    # TODO(Pariksheet): Python invoke_stepper implementation not support now.
    #    out.extend(
    #        _recommend_command(
    #            "invoke_stepper",
    #            "Use the node-stepper interface, which allows you to interactively "
    #            "step through nodes involved in the graph run() call and "
    #            "inspect/modify their values", create_link=True))

    out.append("")

    #    out.append_rich_line(RL("For more details, see ") +
    #                         RL("help.", debugger_cli_common.MenuItem("", "help")) +
    #                         ".")
    #    out.append("")

    # Make main menu for the run-start intro.
    main_menu = debugger_cli_common.Menu()
    main_menu.append(debugger_cli_common.MenuItem("run", "run"))
    # TODO(Pariksheet): Python invoke_stepper implementation not support now.
    #    main_menu.append(debugger_cli_common.MenuItem(
    #        "invoke_stepper", "invoke_stepper"))
    main_menu.append(debugger_cli_common.MenuItem("exit", "exit"))
    out.annotations[debugger_cli_common.MAIN_MENU_KEY] = main_menu

    return out