예제 #1
0
    def list_sorted_nodes(self, args, screen_info=None):
        """List the sorted transitive closure of the stepper's outputs.

        Args:
          args: Command-line arguments (excluding the command prefix), parsed
            with the "list_sorted_nodes" argument parser.
          screen_info: Unused.

        Returns:
          (debugger_cli_common.RichTextLines) One line per node inside the
            selected index range. An explicit range is used only when both
            the lower and upper bound are supplied (i.e. neither is -1);
            otherwise every node is listed, preceded by a header and followed
            by the node-status label legend.
        """
        del screen_info  # Unused.

        parsed = self.arg_parsers["list_sorted_nodes"].parse_args(args)
        num_nodes = len(self._sorted_nodes)

        has_explicit_range = (parsed.lower_bound != -1 and
                              parsed.upper_bound != -1)
        if has_explicit_range:
            begin = max(0, parsed.lower_bound)
            end = min(num_nodes, parsed.upper_bound)
        else:
            begin, end = 0, num_nodes
        verbose = not has_explicit_range

        stepper = self._node_stepper
        handle_node_names = stepper.handle_node_names()
        intermediate_tensor_names = stepper.intermediate_tensor_names()
        override_names = stepper.override_names()
        # Drop the ":<slot>" suffix so variables are compared by node name.
        dirty_variable_names = [
            variable_name.split(":")[0]
            for variable_name in stepper.dirty_variables()
        ]

        if verbose:
            lines = ["Topologically-sorted transitive input(s) and output(s):",
                     ""]
        else:
            lines = []

        for index, element_name in enumerate(self._sorted_nodes):
            if not begin <= index < end:
                continue

            # TODO(cais): Use fixed-width text to show node index.
            if index == self._next:
                prefix = RL("  ") + RL(self.NEXT_NODE_POINTER_STR, "bold")
            else:
                prefix = RL("     ")

            prefix += "(%d / %d)" % (index + 1, num_nodes) + "  ["
            prefix += self._get_status_labels(element_name,
                                              handle_node_names,
                                              intermediate_tensor_names,
                                              override_names,
                                              dirty_variable_names)
            lines.append(prefix + "] " + element_name)

        output = debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
        if verbose:
            output.extend(self._node_status_label_legend())
        return output
예제 #2
0
    def _node_status_label_legend(self):
        """Get legend for node-status labels.

        Each entry shows a state label, colored with its entry in
        `self._STATE_COLORS`, followed by a short description of the state.

        Returns:
          (debugger_cli_common.RichTextLines) Legend text.
        """

        def _entry(state, description):
            # Two-space indent, the colored state label, then the description.
            return (RL("  ") + RL(state, self._STATE_COLORS[state]) +
                    description)

        return debugger_cli_common.rich_text_lines_from_rich_line_list([
            "", "Legend:",
            _entry(self.STATE_IS_PLACEHOLDER, " - Placeholder"),
            _entry(self.STATE_CANT_INPUT, " - Uninputable"),
            _entry(self.STATE_CONT,
                   " - Already continued-to; Tensor handle available from "
                   "output slot(s)"),
            # This entry previously repeated the " - Uninputable" text from
            # STATE_CANT_INPUT (copy-paste), which mislabeled the state.
            _entry(self.STATE_DUMPED_INTERMEDIATE,
                   " - Dumped intermediate tensor"),
            _entry(self.STATE_OVERRIDDEN,
                   " - Has overriding (injected) tensor value"),
            _entry(self.STATE_DIRTY_VARIABLE,
                   " - Dirty variable: Variable already updated this node "
                   "stepper."),
        ])
예제 #3
0
    def _counts_summary(counts, skip_zeros=True, total_count=None):
        """Format (key, value) counts as a two-row table.

        Args:
          counts: Iterable of (count_key, count_val) pairs. count_key is a str;
            count_val is rendered with `str()`.
          skip_zeros: If True, pairs whose value is falsy (e.g. 0) are omitted.
          total_count: Optional total, appended as an extra "total" column.

        Returns:
          (debugger_cli_common.RichTextLines) Two lines: the keys row, then the
            values row, each framed by "|" separators.
        """
        if skip_zeros:
            counts = [(count_key, count_val) for count_key, count_val in counts
                      if count_val]
        else:
            # Materialize so a one-shot iterator survives being read twice.
            counts = list(counts)

        # Stringify each value once. Every count column is one character wider
        # than its widest entry, leaving a separating space between columns.
        cells = [(count_key, str(count_val)) for count_key, count_val in counts]
        max_common_len = max(
            (max(len(count_key), len(count_val_str)) + 1
             for count_key, count_val_str in cells),
            default=0)

        key_line = debugger_cli_common.RichLine("|")
        val_line = debugger_cli_common.RichLine("|")
        for count_key, count_val_str in cells:
            key_line += _pad_string_to_length(count_key, max_common_len)
            val_line += _pad_string_to_length(count_val_str, max_common_len)
        key_line += " |"
        val_line += " |"

        if total_count is not None:
            total_key_str = "total"
            total_val_str = str(total_count)
            # Give the value the same +1 margin as the key and as the count
            # columns above. Previously only the key got it, so a total of
            # six or more characters filled its cell with no separating space.
            total_len = max(len(total_key_str), len(total_val_str)) + 1
            total_key_str = _pad_string_to_length(total_key_str, total_len)
            total_val_str = _pad_string_to_length(total_val_str, total_len)
            key_line += total_key_str + " |"
            val_line += total_val_str + " |"

        return debugger_cli_common.rich_text_lines_from_rich_line_list(
            [key_line, val_line])
예제 #4
0
def _recommend_command(command, description, indent=2, create_link=False):
    """Build a two-line RichTextLines block recommending a CLI command.

    The first line holds the command in bold (optionally also as a clickable
    command link) followed by a colon. The second line holds the description,
    indented three spaces further than the command.

    Args:
      command: (str) The command to recommend.
      description: (str) A description of what the command does.
      indent: (int) Number of leading spaces on both lines.
      create_link: (bool) Whether a command link is to be applied to the command
        string.

    Returns:
      (RichTextLines) Formatted text (with font attributes) for recommending the
        command.
    """
    prefix = " " * indent
    font_attr = ([debugger_cli_common.MenuItem("", command), "bold"]
                 if create_link else "bold")

    command_line = RL(prefix) + RL(command, font_attr) + ":"
    description_line = prefix + "   " + description
    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        [command_line, description_line])
예제 #5
0
def get_error_intro(tvm_error):
    """Generate formatted intro for TVM run-time error.

    Args:
      tvm_error: TVM run-time error object. When it exposes `op.name` (as an
        `errors.OpError` does), the op name is used to suggest debugging
        commands. Otherwise a warning line replaces the suggestions.

    Returns:
      (RichTextLines) Formatted intro message about the run-time error. It
        includes sample debugging commands when the failing op is known.
    """
    # Not every error object carries an op (or an op with a name). Fall back
    # instead of raising AttributeError while reporting the original error.
    op_name = getattr(getattr(tvm_error, "op", None), "name", None)

    intro_lines = [
        "--------------------------------------",
        RL("!!! An error occurred during the run !!!", "blink"),
        "",
    ]
    if op_name is not None:
        intro_lines.append("You may use the following commands to debug:")
    else:
        intro_lines.append(
            "WARNING: Cannot determine the name of the op that caused the "
            "error.")

    out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines)

    if op_name is not None:
        out.extend(
            _recommend_command("ni -a -d -t %s" % op_name,
                               "Inspect information about the failing op.",
                               create_link=True))
        out.extend(
            _recommend_command("li -r %s" % op_name,
                               "List inputs to the failing op, recursively.",
                               create_link=True))
        out.extend(
            _recommend_command(
                "lt",
                "List all tensors dumped during the failing run() call.",
                create_link=True))

    more_lines = [
        "",
        "Op name:    %s" % op_name,
        "Error type: " + str(type(tvm_error)),
        "",
        "Details:",
        str(tvm_error),
        "",
        "WARNING: Using client GraphDef due to the error, instead of "
        "executor GraphDefs.",
        "--------------------------------------",
        "",
    ]

    out.extend(debugger_cli_common.RichTextLines(more_lines))

    return out
예제 #6
0
def error(msg):
    """Wrap an error message as single-line, red-colored screen output.

    Args:
      msg: (str) The error message. It is displayed with an "ERROR: " prefix.

    Returns:
      (debugger_cli_common.RichTextLines) One line containing the prefixed
        message, rendered in COLOR_RED.
    """
    error_text = "ERROR: " + msg
    error_line = RL(error_text, COLOR_RED)
    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        [error_line])
예제 #7
0
def get_tvmdbg_logo():
    """Make an ASCII representation of the tvmdbg logo.

    Each logo row is a concatenation of `RL` segments colored either
    COLOR_BLUE or COLOR_GRAY. A single-space gray line pads the art at the top
    and at the bottom.

    Returns:
      (debugger_cli_common.RichTextLines) The logo, one rich line per row.
    """

    # Top padding line.
    lines = [RL(" ", COLOR_GRAY)]
    # Rows 1-4: blue segment, then two gray segments.
    lines.append(
        RL("@@@@@@@@@  ", COLOR_BLUE) + RL("@@@@@@@     ", COLOR_GRAY) + RL(
            "@@@                                                          @@@  @@@            @@@@        ",
            COLOR_GRAY))
    lines.append(
        RL("@@@@@@@@@  ", COLOR_BLUE) + RL("@@@@@@@     ", COLOR_GRAY) + RL(
            "@@@                                                          @@@  @@@          @@@@@@@@@@@   ",
            COLOR_GRAY))
    lines.append(
        RL("@@@@@@@@@  ", COLOR_BLUE) + RL("@@@@@@@     ", COLOR_GRAY) + RL(
            "@@@@@@@@@  @@@         @@@  @@@   @@@        @@@             @@@  @@@         @@     @@      ",
            COLOR_GRAY))
    lines.append(
        RL("@@@@@@@@@   ", COLOR_BLUE) + RL("@@@@@      ", COLOR_GRAY) + RL(
            "@@@@@@@@@   @@@        @@@  @@@@@@@@@@@@@@@@@@@@@@@          @@@  @@@         @@     @@      ",
            COLOR_GRAY))
    # Rows 5-7: one blue segment, then one gray segment.
    lines.append(
        RL("@@@@@@@@@@@            ", COLOR_BLUE) + RL(
            "@@@          @@@      @@@   @@@@@   @@@@@@@@   @@@@@    @@@@ @@@  @@@ @@@@    @@@@@@@@@      ",
            COLOR_GRAY))
    lines.append(
        RL("@@@@@@@@@@@@@@@@@@@    ", COLOR_BLUE) + RL(
            "@@@          @@@      @@@   @@@@      @@@@@      @@@  @@@@@@@@@@  @@@@@@@@@     @@@@@        ",
            COLOR_GRAY))
    lines.append(
        RL("        @@@@@@@@@@@    ", COLOR_BLUE) + RL(
            "@@@           @@@    @@@    @@@        @@@       @@@  @@@    @@@  @@@    @@@  @@             ",
            COLOR_GRAY))
    # Rows 8-11: gray segment, blue segment, then gray segment.
    lines.append(
        RL("  @@@@@", COLOR_GRAY) + RL("   @@@@@@@@@    ", COLOR_BLUE) + RL(
            "@@@           @@@    @@@    @@@        @@@       @@@  @@@    @@@  @@@    @@@    @@@@@@@@@    ",
            COLOR_GRAY))
    lines.append(
        RL(" @@@@@@@", COLOR_GRAY) + RL("  @@@@@@@@@    ", COLOR_BLUE) + RL(
            "@@@@           @@@  @@@     @@@        @@@       @@@  @@@    @@@  @@@    @@@   @@        @@  ",
            COLOR_GRAY))
    lines.append(
        RL(" @@@@@@@", COLOR_GRAY) + RL("  @@@@@@@@@    ", COLOR_BLUE) + RL(
            "@@@@@@@@@       @@@@@@      @@@        @@@       @@@  @@@@@@@@@@  @@@@@@@@@    @@@@@@@@@@@   ",
            COLOR_GRAY))
    lines.append(
        RL(" @@@@@@@", COLOR_GRAY) + RL("  @@@@@@@@@    ", COLOR_BLUE) + RL(
            " @@@@@@@@        @@@@       @@@        @@@       @@@    @@@@ @@@  @@@ @@@@      @@@@@@@@@    ",
            COLOR_GRAY))
    # Bottom padding line.
    lines.append(RL(" ", COLOR_GRAY))

    return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
예제 #8
0
    def render(self,
               max_length,
               backward_command,
               forward_command,
               latest_command_attribute="black_on_white",
               old_command_attribute="magenta_on_white"):
        """Render the rich text content of the single-line navigation bar.

        Layout: "| <back-arrow> <forward-arrow>". Each arrow is a clickable
        menu item only when navigation in that direction is possible. When
        history items exist, the layout continues with:
          - " | ";
          - a "(-N) " marker when the current item is N steps behind the
            latest one;
          - the current item's command, truncated to the space left within
            `max_length` (omitted if no space is left).

        Args:
          max_length: (`int`) Maximum length of the navigation bar, in
            characters.
          backward_command: (`str`) command for going backward. Used to
            construct the shortcut menu item.
          forward_command: (`str`) command for going forward. Used to
            construct the shortcut menu item.
          latest_command_attribute: font attribute for the latest command.
          old_command_attribute: font attribute for old (non-latest) commands.

        Returns:
          (`debugger_cli_common.RichTextLines`) the navigation bar text with
            attributes.
        """
        back_attr = (debugger_cli_common.MenuItem(None, backward_command)
                     if self.can_go_back() else None)
        forward_attr = (debugger_cli_common.MenuItem(None, forward_command)
                        if self.can_go_forward() else None)
        bar = (RL("| ") + RL(self.BACK_ARROW_TEXT, back_attr) + RL(" ") +
               RL(self.FORWARD_ARROW_TEXT, forward_attr))

        if not self._items:
            return debugger_cli_common.rich_text_lines_from_rich_line_list(
                [bar])

        steps_behind = (len(self._items) - 1) - self._pointer
        if steps_behind:
            command_attribute = old_command_attribute
        else:
            command_attribute = latest_command_attribute

        bar += RL(" | ")
        if steps_behind:
            bar += RL("(-%d) " % steps_behind, command_attribute)

        room_left = max_length - len(bar)
        if room_left > 0:
            current_command = self._items[self._pointer].command
            bar += RL(current_command[:room_left], command_attribute)

        return debugger_cli_common.rich_text_lines_from_rich_line_list([bar])
예제 #9
0
    def summarize(self, highlight=None):
        """Render every config property as an indented "name: value" line.

        Args:
          highlight: A property name to highlight in the output. The matching
            entry's name and value are additionally rendered in bold.

        Returns:
          A `RichTextLines` output: a bold title line, a blank line, then one
          line per property, in `self._config.items()` order.
        """

        def _format_property(name, val):
            attr = "bold" if name == highlight else None
            return (RL("  ") + RL(name, ["underline", attr]) + RL(": ") +
                    RL(str(val), font_attr=attr))

        lines = [RL("Command-line configuration:", "bold"), RL("")]
        lines.extend(
            _format_property(name, val) for name, val in self._config.items())
        return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
예제 #10
0
    def _report_last_updated(self):
        """Generate a report of the variables updated in the last cont/step call.

        Returns:
          (debugger_cli_common.RichTextLines) Empty when nothing was updated.
            Otherwise an "Updated:" header, the updated names in sorted order
            (each indented by two spaces), and a trailing blank line.
        """
        last_updated = self._node_stepper.last_updated()
        if not last_updated:
            return debugger_cli_common.RichTextLines([])

        report = [RL("Updated:", self._UPDATED_ATTRIBUTE)]
        report.extend("  %s" % name for name in sorted(last_updated))
        report.append("")
        return debugger_cli_common.rich_text_lines_from_rich_line_list(report)
예제 #11
0
    def _report_last_input_types(self):
        """Generate a report of the input types used in the cont/step call.

        Returns:
          (debugger_cli_common.RichTextLines) A "Stepper used inputs:" header,
            one "  <name> : <type>" line per input (the type colored via
            `self._INPUT_COLORS`), or "  (No inputs)", then a blank line.
        """
        input_types = self._node_stepper.last_input_types()

        report = ["Stepper used inputs:"]
        if not input_types:
            report.append("  (No inputs)")
        else:
            for input_name in input_types:
                input_type = input_types[input_name]
                report.append(
                    RL("  %s : " % input_name) +
                    RL(input_type, self._INPUT_COLORS[input_type]))
        report.append("")

        return debugger_cli_common.rich_text_lines_from_rich_line_list(report)
예제 #12
0
    def _report_last_feed_types(self):
        """Generate a report of the feed types used in the cont/step call.

        Returns:
          (debugger_cli_common.RichTextLines) A "Stepper used feeds:" header,
            one "  <name> : <type>" line per feed (the type colored via
            `self._FEED_COLORS`), or "  (No feeds)", then a blank line.
        """
        feed_types = self._node_stepper.last_feed_types()

        report = ["Stepper used feeds:"]
        if not feed_types:
            report.append("  (No feeds)")
        else:
            for feed_name in feed_types:
                feed_type = feed_types[feed_name]
                report.append(
                    RL("  %s : " % feed_name) +
                    RL(feed_type, self._FEED_COLORS[feed_type]))
        report.append("")

        return debugger_cli_common.rich_text_lines_from_rich_line_list(report)
예제 #13
0
def _get_list_profile_lines(device_name,
                            device_index,
                            device_count,
                            profile_datum_list,
                            sort_by,
                            sort_reverse,
                            time_unit,
                            device_name_filter=None,
                            node_name_filter=None,
                            op_type_filter=None,
                            screen_cols=80):
    """Get `RichTextLines` object for list_profile command for a given device.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index.
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Identifier of column to sort. Sort identifier
          must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
          SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_MEMORY or SORT_OPS_BY_LINE.
      sort_reverse: (bool) Whether to sort in descending instead of default
          (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.
      screen_cols: (int) Number of columns available on the screen (i.e.,
        available screen width).

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op. The table has a device header line, a header
      row whose column names link to re-sort commands, one row per op, and a
      "Device Total" row.
    """
    profile_data = ProfileDataTableView(profile_datum_list,
                                        time_unit=time_unit)
    # Same filters for every cell lookup.
    value_filters = dict(device_name_filter=device_name_filter,
                         node_name_filter=node_name_filter,
                         op_type_filter=op_type_filter)

    # Calculate total time early to calculate column widths.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    device_total_row = [
        "Device Total", "",
        cli_shared.time_to_readable_str(total_op_time,
                                        force_time_unit=time_unit),
        cli_shared.time_to_readable_str(total_exec_time,
                                        force_time_unit=time_unit)
    ]

    # Each column is as wide as the widest of its header, its total cell and
    # all of its data cells, plus a two-character margin.
    column_names = profile_data.column_names()
    column_widths = [len(column_name) for column_name in column_names]
    for i, total_cell in enumerate(device_total_row):
        column_widths[i] = max(column_widths[i], len(total_cell))
    for i in range(len(column_widths)):
        for row_index in range(profile_data.row_count()):
            # Compare against the running maximum. The previous code compared
            # against the width captured before this loop, so the column was
            # sized by the header/total versus only the *last* row.
            column_widths[i] = max(
                column_widths[i],
                len(profile_data.value(row_index, i, **value_filters)))
        column_widths[i] += 2  # add margin between columns

    # Add device name.
    output = [RL("-" * screen_cols)]
    device_row = "Device %d of %d: %s" % (device_index + 1, device_count,
                                          device_name)
    output.append(RL(device_row))
    output.append(RL())

    # Add headers. Clicking a header sorts by that column; clicking the
    # currently ascending-sorted column again requests reverse order.
    base_command = "list_profile"
    header_row = RL()
    for col in range(profile_data.column_count()):
        column_name = column_names[col]
        sort_id = profile_data.column_sort_id(col)
        command = "%s -s %s" % (base_command, sort_id)
        if sort_by == sort_id and not sort_reverse:
            command += " -r"
        head_menu_item = debugger_cli_common.MenuItem(None, command)
        header_row += RL(column_name, font_attr=[head_menu_item, "bold"])
        header_row += RL(" " * (column_widths[col] - len(column_name)))

    output.append(header_row)

    # Add data rows.
    for row_index in range(profile_data.row_count()):
        new_row = RL()
        for col in range(profile_data.column_count()):
            new_cell = profile_data.value(row_index, col, **value_filters)
            new_row += new_cell
            new_row += RL(" " * (column_widths[col] - len(new_cell)))
        output.append(new_row)

    # Add stat totals, left-aligned in the same column widths.
    totals_str = "".join(
        ("{:<%d}" % column_widths[i]).format(total_cell)
        for i, total_cell in enumerate(device_total_row))
    output.append(RL())
    output.append(RL(totals_str))
    return debugger_cli_common.rich_text_lines_from_rich_line_list(output)
예제 #14
0
    def print_source(self, args, screen_info=None):
        """Print a Python source file with line-level profile information.

        Args:
          args: Command-line arguments, excluding the command prefix, as a list of
            str.
          screen_info: Optional dict input containing screen information such as
            cols.

        Returns:
          Output text lines as a RichTextLines object. It contains a two-line
          header followed by one line per source line. Profiled lines carry a
          cost bar, the total cost and a "nodes(execs)" cell that links to an
          "lp" command for that line. If the file has no profile data under
          the given filters, an explanatory message is returned instead.
        """
        del screen_info

        parsed = self._arg_parsers["print_source"].parse_args(args)
        # Expand "~" once and use the same path for annotating, loading and
        # building the per-line "lp" filter. Previously only the annotation
        # used the expanded path, so e.g. "~/model.py" was annotated but then
        # failed to load.
        source_file_path = os.path.expanduser(parsed.source_file_path)

        device_name_regex = (re.compile(parsed.device_name_filter)
                             if parsed.device_name_filter else None)

        profile_data = []
        data_generator = self._get_profile_data_generator()
        for device_stats in self._run_metadata.step_stats.dev_stats:
            if device_name_regex and not device_name_regex.match(
                    device_stats.device):
                continue
            profile_data.extend(data_generator(device_stats))

        source_annotation = source_utils.annotate_source_against_profile(
            profile_data,
            source_file_path,
            node_name_filter=parsed.node_name_filter,
            op_type_filter=parsed.op_type_filter)
        if not source_annotation:
            return debugger_cli_common.RichTextLines([
                "The source file %s does not contain any profile information for "
                "the previous Session run under the following "
                "filters:" % parsed.source_file_path,
                "  --%s: %s" %
                (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
                "  --%s: %s" %
                (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
                "  --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)
            ])

        max_total_cost = 0
        for line_index in source_annotation:
            total_cost = _get_total_cost(source_annotation[line_index],
                                         parsed.cost_type)
            max_total_cost = max(max_total_cost, total_cost)

        source_lines, line_num_width = source_utils.load_source(
            source_file_path)

        cost_bar_max_length = 10
        total_cost_head = parsed.cost_type
        column_widths = {
            "cost_bar": cost_bar_max_length + 3,
            "total_cost": len(total_cost_head) + 3,
            "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
            "line_number": line_num_width,
        }

        head = RL(
            " " * column_widths["cost_bar"] + total_cost_head + " " *
            (column_widths["total_cost"] - len(total_cost_head)) +
            self._NUM_NODES_HEAD + " " *
            (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
            font_attr=self._LINE_COST_ATTR)
        head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
        sub_head = RL(
            " " * (column_widths["cost_bar"] + column_widths["total_cost"]) +
            self._NUM_EXECS_SUB_HEAD + " " *
            (column_widths["num_nodes_execs"] - len(self._NUM_EXECS_SUB_HEAD))
            + " " * column_widths["line_number"],
            font_attr=self._LINE_COST_ATTR)
        sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
        lines = [head, sub_head]

        # The per-line "lp" command only varies by line number, so build the
        # path filter and the optional filter flags once.
        file_path_filter = re.escape(source_file_path) + "$"
        filter_flags = ""
        if parsed.device_name_filter:
            filter_flags += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                          parsed.device_name_filter)
        if parsed.node_name_filter:
            filter_flags += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                          parsed.node_name_filter)
        if parsed.op_type_filter:
            filter_flags += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                          parsed.op_type_filter)

        # Width of the blank filler used for lines without profile data.
        blank_annotation_width = sum(column_widths[col_name]
                                     for col_name in column_widths
                                     if col_name != "line_number")

        output_annotations = {}
        for lineno, line in enumerate(source_lines, start=1):
            if lineno in source_annotation:
                annotation = source_annotation[lineno]
                line_cost = _get_total_cost(annotation, parsed.cost_type)

                cost_bar = self._render_normalized_cost_bar(
                    line_cost, max_total_cost, cost_bar_max_length)
                annotated_line = cost_bar
                annotated_line += " " * (column_widths["cost_bar"] -
                                         len(cost_bar))

                total_cost = RL(cli_shared.time_to_readable_str(
                    line_cost, force_time_unit=parsed.time_unit),
                                font_attr=self._LINE_COST_ATTR)
                total_cost += " " * (column_widths["total_cost"] -
                                     len(total_cost))
                annotated_line += total_cost

                command = (
                    "lp --file_path_filter %s --min_lineno %d --max_lineno %d"
                    % (file_path_filter, lineno, lineno + 1)) + filter_flags
                menu_item = debugger_cli_common.MenuItem(None, command)
                num_nodes_execs = RL(
                    "%d(%d)" %
                    (annotation.node_count, annotation.node_exec_count),
                    font_attr=[self._LINE_COST_ATTR, menu_item])
                num_nodes_execs += " " * (column_widths["num_nodes_execs"] -
                                          len(num_nodes_execs))
                annotated_line += num_nodes_execs
            else:
                annotated_line = RL(" " * blank_annotation_width)

            line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
            line_num_column += " " * (column_widths["line_number"] -
                                      len(line_num_column))
            annotated_line += line_num_column
            annotated_line += line
            lines.append(annotated_line)

            # Scroll the initial view to the requested line, if any.
            if parsed.init_line == lineno:
                output_annotations[
                    debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

        return debugger_cli_common.rich_text_lines_from_rich_line_list(
            lines, annotations=output_annotations)
예제 #15
0
def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
    """Generate formatted intro for run-start UI.

    Args:
      run_call_count: (int) Run call counter.
      fetches: Fetches of the `GraphRuntime.run()` call. See doc of
        `GraphRuntime.run()` for more details.
      feed_dict: Feeds to the `GraphRuntime.run()` call. See doc of
        `GraphRuntime.run()` for more details.
      tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
        callable.
      is_callable_runner: (bool) whether a runner returned by
          GraphRuntime.make_callable is being run.

    Returns:
      (RichTextLines) Formatted intro message about the `GraphRuntime.run()`
        call. Its annotations carry a main menu with "run" and "exit" items.
    """

    fetch_lines = common.get_flattened_names(fetches)

    if not feed_dict:
        feed_rich_lines = [debugger_cli_common.RichLine("  (Empty)")]
    else:
        feed_rich_lines = []
        for feed_key in feed_dict:
            feed_key_name = common.get_graph_element_name(feed_key)
            feed_dict_line = debugger_cli_common.RichLine("  ")
            # The "pf" link surrounds the name string with quotes, because
            # feed_key_name may contain spaces in some cases, e.g.,
            # SparseTensors.
            feed_dict_line += debugger_cli_common.RichLine(
                feed_key_name,
                debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
            feed_rich_lines.append(feed_dict_line)
    feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
        feed_rich_lines)

    out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
    if is_callable_runner:
        out.append(
            " Running a runner returned by GraphRuntime.make_callable()")
    else:
        out.append(" GraphRuntime.run() call #%d:" % run_call_count)
        out.append("")
        out.append(" Output:")
        out.extend(
            debugger_cli_common.RichTextLines(
                ["   " + line for line in fetch_lines]))
        out.append("")
        out.append(" Inputs:")
        out.extend(feed_dict_lines)
    out.append(_HORIZONTAL_BAR)
    out.append("")
    out.append(" Select one of the following commands to proceed ---->")

    out.extend(
        _recommend_command("run",
                           "Execute the run() call with debug tensor-watching",
                           create_link=True))
    out.extend(
        _recommend_command(
            "run -n",
            "Execute the run() call without debug tensor-watching",
            create_link=True))
    out.extend(
        _recommend_command(
            "run -t <T>",
            "Execute run() calls (T - 1) times without debugging, then "
            "execute run() once more with debugging and drop back to the CLI"))
    out.extend(
        _recommend_command(
            "run -f <filter_name>",
            "Keep executing run() calls until a dumped tensor passes a given, "
            "registered filter (conditional breakpoint mode)"))

    # List the registered filters, each linked to its "run -f" command.
    more_lines = ["    Registered filter(s):"]
    if tensor_filters:
        for filter_name in tensor_filters:
            command_menu_node = debugger_cli_common.MenuItem(
                "", "run -f %s" % filter_name)
            more_lines.append(
                RL("        * ") + RL(filter_name, command_menu_node))
    else:
        more_lines.append("        (None)")

    out.extend(
        debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

    # TODO(Pariksheet): Python invoke_stepper implementation not support now.
    #    out.extend(
    #        _recommend_command(
    #            "invoke_stepper",
    #            "Use the node-stepper interface, which allows you to interactively "
    #            "step through nodes involved in the graph run() call and "
    #            "inspect/modify their values", create_link=True))

    out.append("")

    #    out.append_rich_line(RL("For more details, see ") +
    #                         RL("help.", debugger_cli_common.MenuItem("", "help")) +
    #                         ".")
    #    out.append("")

    # Make main menu for the run-start intro.
    menu = debugger_cli_common.Menu()
    menu.append(debugger_cli_common.MenuItem("run", "run"))
    # TODO(Pariksheet): Python invoke_stepper implementation not support now.
    #    menu.append(debugger_cli_common.MenuItem(
    #        "invoke_stepper", "invoke_stepper"))
    menu.append(debugger_cli_common.MenuItem("exit", "exit"))
    out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

    return out
예제 #16
0
def format_tensor(tensor,
                  tensor_name,
                  np_printoptions,
                  print_all=False,
                  tensor_slicing=None,
                  highlight_options=None,
                  include_numeric_summary=False,
                  write_path=None):
    """Generate formatted str to represent a tensor or its slices.

    Args:
      tensor: (numpy ndarray) The tensor value.
      tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key.
      np_printoptions: (dict) Numpy tensor formatting options. The dict is not
        modified; a copy with "threshold" set is passed on for formatting.
        None is treated as an empty dict.
      print_all: (bool) Whether the tensor is to be displayed in its entirety,
        instead of printing ellipses, even if its number of elements exceeds
        the default numpy display threshold.
        (Note: Even if this is set to true, the screen output can still be cut
         off by the UI frontend if it consist of more lines than the frontend
         can handle.)
      tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
        None, no slicing will be performed on the tensor.
      highlight_options: (tensor_format.HighlightOptions) options to highlight
        elements of the tensor. See the doc of tensor_format.format_tensor()
        for more details.
      include_numeric_summary: Whether a text summary of the numeric values (if
        applicable) will be included.
      write_path: A path to save the tensor value (after any slicing) to
        (optional). `numpy.save()` is used to save the value.

    Returns:
      An instance of `debugger_cli_common.RichTextLines` representing the
      (potentially sliced) tensor.
    """

    if tensor_slicing:
        # Validate the indexing.
        value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing)
        sliced_name = tensor_name + tensor_slicing
    else:
        value = tensor
        sliced_name = tensor_name

    auxiliary_message = None
    if write_path:
        with open(write_path, "wb") as output_file:
            np.save(output_file, value)
        line = debugger_cli_common.RichLine("Saved value to: ")
        line += debugger_cli_common.RichLine(write_path, font_attr="bold")
        line += " (%sB)" % bytes_to_readable_str(os.stat(write_path).st_size)
        auxiliary_message = debugger_cli_common.rich_text_lines_from_rich_line_list(
            [line, debugger_cli_common.RichLine("")])

    # Work on a copy. Previously "threshold" was written into the caller's
    # dict, which leaked this call's print_all choice into later uses of it.
    printoptions = dict(np_printoptions) if np_printoptions else {}
    if print_all:
        printoptions["threshold"] = value.size
    else:
        printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD

    return tensor_format.format_tensor(
        value,
        sliced_name,
        include_metadata=True,
        include_numeric_summary=include_numeric_summary,
        auxiliary_message=auxiliary_message,
        np_printoptions=printoptions,
        highlight_options=highlight_options)