def testFormatAsSingleLineWithDisabledNode(self):
        """A disabled menu item is rendered with the disabled-item font attrs.

        NOTE(review): assumes setUp created self.menu with two enabled items
        (self.node1 "water flower", self.node2 "measure wavelength") already
        appended — confirm against the fixture, which is not visible here.
        """
        # Append a third, disabled item to the fixture menu.
        node2 = debugger_cli_common.MenuItem("write poem",
                                             "write_poem",
                                             enabled=False)
        self.menu.append(node2)

        output = self.menu.format_as_single_line(prefix="Menu: ",
                                                 divider=", ",
                                                 disabled_item_attrs="bold")
        self.assertEqual(
            ["Menu: water flower, measure wavelength, write poem, "],
            output.lines)
        # Enabled items carry their MenuItem objects as font attributes ...
        self.assertEqual((6, 18, [self.node1]), output.font_attr_segs[0][0])
        self.assertEqual((20, 38, [self.node2]), output.font_attr_segs[0][1])
        # ... while the disabled item carries only the disabled_item_attrs.
        self.assertEqual((40, 50, ["bold"]), output.font_attr_segs[0][2])
  def _run_info_handler(self, args, screen_info=None):
    """Assemble the run-info screen output.

    Shows the tfdbg logo on the very first run() call, then the stored
    run-info text. If the output carries a main menu and we are not at a
    run-start prompt, a "list_tensors" shortcut is prepended to that menu.

    Args:
      args: Command-line arguments (unused).
      screen_info: Optional dict of screen information (unused).

    Returns:
      A RichTextLines object with the run-info screen content.
    """
    result = debugger_cli_common.RichTextLines([])

    if self._run_call_count == 1:
      result.extend(cli_shared.get_tfdbg_logo())
    result.extend(self._run_info)

    has_menu = debugger_cli_common.MAIN_MENU_KEY in result.annotations
    if has_menu and not self._is_run_start:
      main_menu = result.annotations[debugger_cli_common.MAIN_MENU_KEY]
      # Avoid inserting a duplicate shortcut on repeated invocations.
      if "list_tensors" not in main_menu.captions():
        shortcut = debugger_cli_common.MenuItem("list_tensors", "list_tensors")
        main_menu.insert(0, shortcut)

    return result
# Example #3
    def _list_node_dumps(self, node_name):
        """List dumped tensor data from a node.

        Args:
          node_name: Name of the node whose dumped tensors are to be listed.

        Returns:
          A RichTextLines object: a count header followed by one line per
          dump, each carrying a clickable "pt" (print_tensor) menu item.
        """
        text_lines = []
        attr_segs = {}

        dump_index = 0
        for watch_key in self._debug_dump.debug_watch_keys(node_name):
            for datum in self._debug_dump.watch_key_to_data(watch_key):
                # Timestamps are in microseconds relative to t0; show ms.
                rel_ms = (datum.timestamp - self._debug_dump.t0) / 1000.0
                entry = "  Slot %d @ %s @ %.3f ms" % (
                    datum.output_slot, datum.debug_op, rel_ms)
                text_lines.append(entry)
                # Link the line (past the 2-space indent) to a print_tensor
                # command selecting this specific dump by index.
                pt_command = "pt %s:%d -n %d" % (
                    node_name, datum.output_slot, dump_index)
                attr_segs[len(text_lines) - 1] = [
                    (2, len(entry),
                     debugger_cli_common.MenuItem(None, pt_command))
                ]
                dump_index += 1

        body = debugger_cli_common.RichTextLines(
            text_lines, font_attr_segs=attr_segs)
        result = debugger_cli_common.RichTextLines(
            ["%d dumped tensor(s):" % dump_index, ""])
        result.extend(body)
        return result
# Example #4
  def print_source(self, args, screen_info=None):
    """Print a Python source file with line-level profile information.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """
    del screen_info  # Unused: the output width is not adapted to the screen.

    parsed = self._arg_parsers["print_source"].parse_args(args)

    # Pre-compile the optional device-name filter; applied per device below.
    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    # Collect profile data from every device that passes the filter.
    profile_data = []
    data_generator = self._get_profile_data_generator()
    device_count = len(self._run_metadata.step_stats.dev_stats)
    for index in range(device_count):
      device_stats = self._run_metadata.step_stats.dev_stats[index]
      if device_name_regex and not device_name_regex.match(device_stats.device):
        continue
      profile_data.extend(data_generator(device_stats))

    # Map 1-based line numbers of the source file to profile information,
    # honoring the node-name and op-type filters.
    source_annotation = source_utils.annotate_source_against_profile(
        profile_data,
        os.path.expanduser(parsed.source_file_path),
        node_name_filter=parsed.node_name_filter,
        op_type_filter=parsed.op_type_filter)
    if not source_annotation:
      # Nothing matched: report back the filters that were in effect.
      return debugger_cli_common.RichTextLines(
          ["The source file %s does not contain any profile information for "
           "the previous Session run under the following "
           "filters:" % parsed.source_file_path,
           "  --%s: %s" % (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
           "  --%s: %s" % (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
           "  --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)])

    # The most expensive line's cost normalizes the per-line cost bars.
    max_total_cost = 0
    for line_index in source_annotation:
      total_cost = self._get_total_cost(source_annotation[line_index],
                                        parsed.cost_type)
      max_total_cost = max(max_total_cost, total_cost)

    source_lines, line_num_width = source_utils.load_source(
        parsed.source_file_path)

    # Fixed-width columns: cost bar, total cost, node/exec counts, line
    # number; the source text follows.
    cost_bar_max_length = 10
    total_cost_head = parsed.cost_type
    column_widths = {
        "cost_bar": cost_bar_max_length + 3,
        "total_cost": len(total_cost_head) + 3,
        "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
        "line_number": line_num_width,
    }

    # Two header rows (head + sub_head), each padded to the column widths.
    head = RL(
        " " * column_widths["cost_bar"] +
        total_cost_head +
        " " * (column_widths["total_cost"] - len(total_cost_head)) +
        self._NUM_NODES_HEAD +
        " " * (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
        font_attr=self._LINE_COST_ATTR)
    head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
    sub_head = RL(
        " " * (column_widths["cost_bar"] +
               column_widths["total_cost"]) +
        self._NUM_EXECS_SUB_HEAD +
        " " * (column_widths["num_nodes_execs"] -
               len(self._NUM_EXECS_SUB_HEAD)) +
        " " * column_widths["line_number"],
        font_attr=self._LINE_COST_ATTR)
    sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
    lines = [head, sub_head]

    output_annotations = {}
    for i, line in enumerate(source_lines):
      lineno = i + 1  # Source line numbers are 1-based.
      if lineno in source_annotation:
        annotation = source_annotation[lineno]
        # Bar length is proportional to this line's cost relative to the
        # most expensive line.
        cost_bar = self._render_normalized_cost_bar(
            self._get_total_cost(annotation, parsed.cost_type), max_total_cost,
            cost_bar_max_length)
        annotated_line = cost_bar
        annotated_line += " " * (column_widths["cost_bar"] - len(cost_bar))

        total_cost = RL(cli_shared.time_to_readable_str(
            self._get_total_cost(annotation, parsed.cost_type),
            force_time_unit=parsed.time_unit),
                        font_attr=self._LINE_COST_ATTR)
        total_cost += " " * (column_widths["total_cost"] - len(total_cost))
        annotated_line += total_cost

        # Build a list_profile ("lp") command scoped to exactly this line,
        # preserving any active filters; attach it as a clickable menu item
        # on the node/exec-count cell.
        file_path_filter = re.escape(parsed.source_file_path) + "$"
        command = "lp --file_path_filter %s --min_lineno %d --max_lineno %d" % (
            file_path_filter, lineno, lineno + 1)
        if parsed.device_name_filter:
          command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                   parsed.device_name_filter)
        if parsed.node_name_filter:
          command += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                   parsed.node_name_filter)
        if parsed.op_type_filter:
          command += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                   parsed.op_type_filter)
        menu_item = debugger_cli_common.MenuItem(None, command)
        num_nodes_execs = RL("%d(%d)" % (annotation.node_count,
                                         annotation.node_exec_count),
                             font_attr=[self._LINE_COST_ATTR, menu_item])
        num_nodes_execs += " " * (
            column_widths["num_nodes_execs"] - len(num_nodes_execs))
        annotated_line += num_nodes_execs
      else:
        # Lines without profile data get blank padding for all cost columns.
        annotated_line = RL(
            " " * sum(column_widths[col_name] for col_name in column_widths
                      if col_name != "line_number"))

      line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
      line_num_column += " " * (
          column_widths["line_number"] - len(line_num_column))
      annotated_line += line_num_column
      annotated_line += line
      lines.append(annotated_line)

      # Record the initial scroll position if this is the requested line.
      if parsed.init_line == lineno:
        output_annotations[
            debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        lines, annotations=output_annotations)
# Example #5
  def _get_list_profile_lines(
      self, device_name, device_index, device_count,
      profile_datum_list, sort_by, sort_reverse, time_unit,
      device_name_filter=None, node_name_filter=None, op_type_filter=None,
      screen_cols=80):
    """Get `RichTextLines` object for list_profile command for a given device.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index.
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Identifier of column to sort. Sort identifier
          must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
          SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_MEMORY or SORT_OPS_BY_LINE.
      sort_reverse: (bool) Whether to sort in descending instead of default
          (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.
      screen_cols: (int) Number of columns available on the screen (i.e.,
        available screen width).

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op.
    """
    profile_data = ProfileDataTableView(profile_datum_list, time_unit=time_unit)

    # Calculate total time early to calculate column widths.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    device_total_row = [
        "Device Total", "",
        cli_shared.time_to_readable_str(total_op_time,
                                        force_time_unit=time_unit),
        cli_shared.time_to_readable_str(total_exec_time,
                                        force_time_unit=time_unit)]

    # Calculate column widths: wide enough for the header, the totals row,
    # and every data cell in the column.
    column_widths = [
        len(column_name) for column_name in profile_data.column_names()]
    for col in range(len(device_total_row)):
      column_widths[col] = max(column_widths[col], len(device_total_row[col]))
    for col in range(len(column_widths)):
      for row in range(profile_data.row_count()):
        column_widths[col] = max(
            column_widths[col], len(profile_data.value(
                row,
                col,
                device_name_filter=device_name_filter,
                node_name_filter=node_name_filter,
                op_type_filter=op_type_filter)))
      column_widths[col] += 2  # add margin between columns

    # Add device name.
    output = [RL("-" * screen_cols)]
    device_row = "Device %d of %d: %s" % (
        device_index + 1, device_count, device_name)
    output.append(RL(device_row))
    output.append(RL())

    # Add headers. Each header is a clickable menu item that re-sorts by
    # that column; clicking the currently-sorted column toggles direction.
    base_command = "list_profile"
    row = RL()
    for col in range(profile_data.column_count()):
      column_name = profile_data.column_names()[col]
      sort_id = profile_data.column_sort_id(col)
      command = "%s -s %s" % (base_command, sort_id)
      if sort_by == sort_id and not sort_reverse:
        command += " -r"
      head_menu_item = debugger_cli_common.MenuItem(None, command)
      row += RL(column_name, font_attr=[head_menu_item, "bold"])
      row += RL(" " * (column_widths[col] - len(column_name)))

    output.append(row)

    # Add data rows, padding each cell to its column width.
    for row in range(profile_data.row_count()):
      new_row = RL()
      for col in range(profile_data.column_count()):
        new_cell = profile_data.value(
            row,
            col,
            device_name_filter=device_name_filter,
            node_name_filter=node_name_filter,
            op_type_filter=op_type_filter)
        new_row += new_cell
        new_row += RL(" " * (column_widths[col] - len(new_cell)))
      output.append(new_row)

    # Add stat totals as a final left-aligned row.
    row_str = ""
    for width, row in zip(column_widths, device_total_row):
      row_str += ("{:<%d}" % width).format(row)
    output.append(RL())
    output.append(RL(row_str))
    return debugger_cli_common.rich_text_lines_from_rich_line_list(output)
# Example #6
def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
    """Generate formatted intro for run-start UI.

    Args:
      run_call_count: (int) Run call counter.
      fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
        for more details.
      feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
        for more details.
      tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
        callable.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.

    Returns:
      (RichTextLines) Formatted intro message about the `Session.run()` call.
    """

    fetch_lines = _get_fetch_names(fetches)

    if not feed_dict:
        feed_dict_lines = [debugger_cli_common.RichLine("  (Empty)")]
    else:
        # One line per feed key; each is a clickable "pf" (print_feed) link.
        feed_dict_lines = []
        for feed_key in feed_dict:
            feed_key_name = get_graph_element_name(feed_key)
            feed_dict_line = debugger_cli_common.RichLine("  ")
            # Surround the name string with quotes, because feed_key_name may
            # contain spaces in some cases, e.g., SparseTensors.
            feed_dict_line += debugger_cli_common.RichLine(
                feed_key_name,
                debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
            feed_dict_lines.append(feed_dict_line)
    feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
        feed_dict_lines)

    out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
    if is_callable_runner:
        out.append("Running a runner returned by Session.make_callable()")
    else:
        # Fetches/feeds are only meaningful for a plain Session.run() call.
        out.append("Session.run() call #%d:" % run_call_count)
        out.append("")
        out.append("Fetch(es):")
        out.extend(
            debugger_cli_common.RichTextLines(
                ["  " + line for line in fetch_lines]))
        out.append("")
        out.append("Feed dict:")
        out.extend(feed_dict_lines)
    out.append(_HORIZONTAL_BAR)
    out.append("")
    out.append("Select one of the following commands to proceed ---->")

    # Recommended next commands; create_link makes the command clickable.
    out.extend(
        _recommend_command("run",
                           "Execute the run() call with debug tensor-watching",
                           create_link=True))
    out.extend(
        _recommend_command(
            "run -n",
            "Execute the run() call without debug tensor-watching",
            create_link=True))
    out.extend(
        _recommend_command(
            "run -t <T>",
            "Execute run() calls (T - 1) times without debugging, then "
            "execute run() once more with debugging and drop back to the CLI"))
    out.extend(
        _recommend_command(
            "run -f <filter_name>",
            "Keep executing run() calls until a dumped tensor passes a given, "
            "registered filter (conditional breakpoint mode)"))

    # List registered tensor filters, each as a clickable "run -f" link.
    more_lines = ["    Registered filter(s):"]
    if tensor_filters:
        filter_names = []
        for filter_name in tensor_filters:
            filter_names.append(filter_name)
            command_menu_node = debugger_cli_common.MenuItem(
                "", "run -f %s" % filter_name)
            more_lines.append(
                RL("        * ") + RL(filter_name, command_menu_node))
    else:
        more_lines.append("        (None)")

    out.extend(
        debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

    out.extend(
        _recommend_command(
            "invoke_stepper",
            "Use the node-stepper interface, which allows you to interactively "
            "step through nodes involved in the graph run() call and "
            "inspect/modify their values",
            create_link=True))

    out.append("")

    out.append_rich_line(
        RL("For more details, see ") +
        RL("help.", debugger_cli_common.MenuItem("", "help")) + ".")
    out.append("")

    # Make main menu for the run-start intro.
    menu = debugger_cli_common.Menu()
    menu.append(debugger_cli_common.MenuItem("run", "run"))
    menu.append(
        debugger_cli_common.MenuItem("invoke_stepper", "invoke_stepper"))
    menu.append(debugger_cli_common.MenuItem("exit", "exit"))
    out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

    return out
  def testCommandTypeConstructorSucceeds(self):
    """A MenuItem built with caption and content exposes both unchanged."""
    item = debugger_cli_common.MenuItem("water flower", "water_flower")

    self.assertEqual("water flower", item.caption)
    self.assertEqual("water_flower", item.content)
# Example #8
def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters):
  """Generate formatted intro for run-start UI.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
      callable.

  Returns:
    (RichTextLines) Formatted intro message about the `Session.run()` call.
  """

  fetch_lines = _get_fetch_names(fetches)

  # Describe the feeds: one name per line, or "(Empty)".
  if not feed_dict:
    feed_dict_lines = ["(Empty)"]
  else:
    feed_dict_lines = []
    for feed_key in feed_dict:
      # Feed keys may be either tensor names (str) or tensor objects.
      if isinstance(feed_key, six.string_types):
        feed_dict_lines.append(feed_key)
      else:
        feed_dict_lines.append(feed_key.name)

  intro_lines = [
      "======================================",
      "Session.run() call #%d:" % run_call_count,
      "", "Fetch(es):"
  ]
  intro_lines.extend(["  " + line for line in fetch_lines])
  intro_lines.extend(["", "Feed dict(s):"])
  intro_lines.extend(["  " + line for line in feed_dict_lines])
  intro_lines.extend([
      "======================================", "",
      "Select one of the following commands to proceed ---->"
  ])

  out = debugger_cli_common.RichTextLines(intro_lines)

  # Recommended next commands; create_link makes the command clickable.
  out.extend(
      _recommend_command(
          "run",
          "Execute the run() call with debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -n",
          "Execute the run() call without debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -t <T>",
          "Execute run() calls (T - 1) times without debugging, then "
          "execute run() once more with debugging and drop back to the CLI"))
  out.extend(
      _recommend_command(
          "run -f <filter_name>",
          "Keep executing run() calls until a dumped tensor passes a given, "
          "registered filter (conditional breakpoint mode)"))

  # List registered tensor filters, each as a clickable "run -f" link.
  more_lines = ["    Registered filter(s):"]
  if tensor_filters:
    filter_names = []
    for filter_name in tensor_filters:
      filter_names.append(filter_name)
      command_menu_node = debugger_cli_common.MenuItem(
          "", "run -f %s" % filter_name)
      more_lines.append(RL("        * ") + RL(filter_name, command_menu_node))
  else:
    more_lines.append("        (None)")

  out.extend(
      debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

  out.extend(
      _recommend_command(
          "invoke_stepper",
          "Use the node-stepper interface, which allows you to interactively "
          "step through nodes involved in the graph run() call and "
          "inspect/modify their values", create_link=True))

  out.append("")

  out.append_rich_line(RL("For more details, see ") +
                       RL("help.", debugger_cli_common.MenuItem("", "help")) +
                       ".")
  out.append("")

  # Make main menu for the run-start intro.
  menu = debugger_cli_common.Menu()
  menu.append(debugger_cli_common.MenuItem("run", "run"))
  menu.append(debugger_cli_common.MenuItem(
      "invoke_stepper", "invoke_stepper"))
  menu.append(debugger_cli_common.MenuItem("exit", "exit"))
  out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

  return out
# Example #9
def _add_main_menu(output,
                   node_name=None,
                   enable_list_tensors=True,
                   enable_node_info=True,
                   enable_print_tensor=True,
                   enable_list_inputs=True,
                   enable_list_outputs=True):
    """Generate main menu for the screen output from a command.

    Args:
      output: (debugger_cli_common.RichTextLines) the output object to modify.
      node_name: (str or None) name of the node involved (if any). If None,
        the menu items node_info, list_inputs and list_outputs will be
        automatically disabled, overriding the values of arguments
        enable_node_info, enable_list_inputs and enable_list_outputs.
      enable_list_tensors: (bool) whether the list_tensor menu item will be
        enabled.
      enable_node_info: (bool) whether the node_info item will be enabled.
      enable_print_tensor: (bool) whether the print_tensor item will be
        enabled.
      enable_list_inputs: (bool) whether the item list_inputs will be enabled.
      enable_list_outputs: (bool) whether the item list_outputs will be
        enabled.
    """
    menu = debugger_cli_common.Menu()

    menu.append(debugger_cli_common.MenuItem(
        "list_tensors", "list_tensors", enabled=enable_list_tensors))

    if node_name:
        # Node-specific items, each with its command bound to node_name.
        node_specific_items = (
            ("node_info", "node_info -a -d %s" % node_name, enable_node_info),
            ("print_tensor", "print_tensor %s" % node_name,
             enable_print_tensor),
            ("list_inputs", "list_inputs -c -r %s" % node_name,
             enable_list_inputs),
            ("list_outputs", "list_outputs -c -r %s" % node_name,
             enable_list_outputs),
        )
        for caption, content, is_enabled in node_specific_items:
            menu.append(debugger_cli_common.MenuItem(
                caption, content, enabled=is_enabled))
    else:
        # Without a node name, the node-specific items are shown but
        # disabled, regardless of the enable_* arguments.
        for caption in ("node_info", "print_tensor", "list_inputs",
                        "list_outputs"):
            menu.append(
                debugger_cli_common.MenuItem(caption, None, enabled=False))

    menu.append(debugger_cli_common.MenuItem("run_info", "run_info"))
    menu.append(debugger_cli_common.MenuItem("help", "help"))

    output.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu
# Example #10
    def list_tensors(self, args, screen_info=None):
        """Command handler for list_tensors.

    List tensors dumped during debugged Session.run() call.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """

        # TODO(cais): Add annotations of substrings for dumped tensor names, to
        # facilitate on-screen highlighting/selection of node names.
        _ = screen_info

        parsed = self._arg_parsers["list_tensors"].parse_args(args)

        # `output` starts life as a plain list of strings; it is rebound to a
        # RichTextLines object further below.
        output = []
        font_attr_segs = {}

        # Compile the optional op-type and node-name regex filters, recording
        # a human-readable description of each active filter.
        filter_strs = []
        if parsed.op_type_filter:
            op_type_regex = re.compile(parsed.op_type_filter)
            filter_strs.append("Op type regex filter: \"%s\"" %
                               parsed.op_type_filter)
        else:
            op_type_regex = None

        if parsed.node_name_filter:
            node_name_regex = re.compile(parsed.node_name_filter)
            filter_strs.append("Node name regex filter: \"%s\"" %
                               parsed.node_name_filter)
        else:
            node_name_regex = None

        filter_output = debugger_cli_common.RichTextLines(filter_strs)

        if parsed.tensor_filter:
            try:
                filter_callable = self.get_tensor_filter(parsed.tensor_filter)
            except ValueError:
                # Unknown filter name: report the error with a main menu that
                # keeps list_tensors disabled (we are already on this screen).
                output = cli_shared.error(
                    "There is no tensor filter named \"%s\"." %
                    parsed.tensor_filter)
                _add_main_menu(output,
                               node_name=None,
                               enable_list_tensors=False)
                return output

            data_to_show = self._debug_dump.find(filter_callable)
        else:
            data_to_show = self._debug_dump.dumped_tensor_data

        # TODO(cais): Implement filter by lambda on tensor value.

        dump_count = 0
        for dump in data_to_show:
            if node_name_regex and not node_name_regex.match(dump.node_name):
                continue

            if op_type_regex:
                op_type = self._debug_dump.node_op_type(dump.node_name)
                if not op_type_regex.match(op_type):
                    continue

            # One line per dump: relative timestamp (ms) plus tensor name;
            # the tensor-name span links to a "pt" (print_tensor) command.
            rel_time = (dump.timestamp - self._debug_dump.t0) / 1000.0
            dumped_tensor_name = "%s:%d" % (dump.node_name, dump.output_slot)
            output.append("[%.3f ms] %s" % (rel_time, dumped_tensor_name))
            font_attr_segs[len(output) - 1] = [
                (len(output[-1]) - len(dumped_tensor_name), len(output[-1]),
                 debugger_cli_common.MenuItem("",
                                              "pt %s" % dumped_tensor_name))
            ]
            dump_count += 1

        # Blank separator between the filter description and the dump list.
        filter_output.append("")
        filter_output.extend(
            debugger_cli_common.RichTextLines(output,
                                              font_attr_segs=font_attr_segs))
        output = filter_output

        # Prepend a headline summarizing the dump count (and filter, if any).
        if parsed.tensor_filter:
            output.prepend([
                "%d dumped tensor(s) passing filter \"%s\":" %
                (dump_count, parsed.tensor_filter)
            ])
        else:
            output.prepend(["%d dumped tensor(s):" % dump_count])

        _add_main_menu(output, node_name=None, enable_list_tensors=False)
        return output
# Example #11
    def print_tensor(self, args, screen_info=None):
        """Command handler for print_tensor.

    Print value of a given dumped tensor.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """

        parsed = self._arg_parsers["print_tensor"].parse_args(args)

        # Match numpy's print width to the screen width, if known.
        if screen_info and "cols" in screen_info:
            np_printoptions = {"linewidth": screen_info["cols"]}
        else:
            np_printoptions = {}

        # Determine if any range-highlighting is required.
        highlight_options = cli_shared.parse_ranges_highlight(parsed.ranges)

        # Split e.g. "node:0[0:10]" into the tensor name and the slicing spec.
        tensor_name, tensor_slicing = (
            command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))

        node_name, output_slot = debug_data.parse_node_or_tensor_name(
            tensor_name)
        # Only validate node existence when partition graphs were loaded;
        # otherwise node_exists() cannot be answered reliably.
        if (self._debug_dump.loaded_partition_graphs()
                and not self._debug_dump.node_exists(node_name)):
            output = cli_shared.error(
                "Node \"%s\" does not exist in partition graphs" % node_name)
            _add_main_menu(output,
                           node_name=None,
                           enable_list_tensors=True,
                           enable_print_tensor=False)
            return output

        watch_keys = self._debug_dump.debug_watch_keys(node_name)
        # If the user gave only a node name, try to infer the output slot.
        if output_slot is None:
            output_slots = set()
            for watch_key in watch_keys:
                # Watch keys look like "<node>:<slot>:<debug_op>"; extract
                # the slot number.
                output_slots.add(int(watch_key.split(":")[1]))

            if len(output_slots) == 1:
                # There is only one dumped tensor from this node, so there is no
                # ambiguity. Proceed to show the only dumped tensor.
                output_slot = list(output_slots)[0]
            else:
                # There are more than one dumped tensors from this node. Indicate as
                # such.
                # TODO(cais): Provide an output screen with command links for
                # convenience.
                lines = [
                    "Node \"%s\" generated debug dumps from %s output slots:" %
                    (node_name, len(output_slots)),
                    "Please specify the output slot: %s:x." % node_name
                ]
                output = debugger_cli_common.RichTextLines(lines)
                _add_main_menu(output,
                               node_name=node_name,
                               enable_list_tensors=True,
                               enable_print_tensor=False)
                return output

        # Find debug dump data that match the tensor name (node name + output
        # slot).
        matching_data = []
        for watch_key in watch_keys:
            debug_tensor_data = self._debug_dump.watch_key_to_data(watch_key)
            for datum in debug_tensor_data:
                if datum.output_slot == output_slot:
                    matching_data.append(datum)

        if not matching_data:
            # No dump for this tensor.
            output = cli_shared.error(
                "Tensor \"%s\" did not generate any dumps." %
                parsed.tensor_name)
        elif len(matching_data) == 1:
            # There is only one dump for this tensor.
            if parsed.number <= 0:
                output = cli_shared.format_tensor(
                    matching_data[0].get_tensor(),
                    matching_data[0].watch_key,
                    np_printoptions,
                    print_all=parsed.print_all,
                    tensor_slicing=tensor_slicing,
                    highlight_options=highlight_options)
            else:
                # A positive -n index is invalid when only one dump exists.
                output = cli_shared.error(
                    "Invalid number (%d) for tensor %s, which generated one dump."
                    % (parsed.number, parsed.tensor_name))

            _add_main_menu(output,
                           node_name=node_name,
                           enable_print_tensor=False)
        else:
            # There are more than one dumps for this tensor.
            if parsed.number < 0:
                # No dump selected: list all dumps, each line linking to a
                # print_tensor command selecting that dump by index.
                lines = [
                    "Tensor \"%s\" generated %d dumps:" %
                    (parsed.tensor_name, len(matching_data))
                ]
                font_attr_segs = {}

                for i, datum in enumerate(matching_data):
                    rel_time = (datum.timestamp - self._debug_dump.t0) / 1000.0
                    lines.append("#%d [%.3f ms] %s" %
                                 (i, rel_time, datum.watch_key))
                    command = "print_tensor %s -n %d" % (parsed.tensor_name, i)
                    font_attr_segs[len(lines) - 1] = [
                        (len(lines[-1]) - len(datum.watch_key), len(lines[-1]),
                         debugger_cli_common.MenuItem(None, command))
                    ]

                lines.append("")
                lines.append(
                    "You can use the -n (--number) flag to specify which dump to "
                    "print.")
                lines.append("For example:")
                lines.append("  print_tensor %s -n 0" % parsed.tensor_name)

                output = debugger_cli_common.RichTextLines(
                    lines, font_attr_segs=font_attr_segs)
            elif parsed.number >= len(matching_data):
                output = cli_shared.error(
                    "Specified number (%d) exceeds the number of available dumps "
                    "(%d) for tensor %s" %
                    (parsed.number, len(matching_data), parsed.tensor_name))
            else:
                # Print the selected dump, tagging the header with its index.
                output = cli_shared.format_tensor(
                    matching_data[parsed.number].get_tensor(),
                    matching_data[parsed.number].watch_key +
                    " (dump #%d)" % parsed.number,
                    np_printoptions,
                    print_all=parsed.print_all,
                    tensor_slicing=tensor_slicing,
                    highlight_options=highlight_options)
            _add_main_menu(output,
                           node_name=node_name,
                           enable_print_tensor=False)

        return output
Example #12
0
    def _dfs_from_node(self,
                       lines,
                       attr_segs,
                       node_name,
                       tracker,
                       max_depth,
                       depth,
                       unfinished,
                       include_control=False,
                       show_op_type=False,
                       command_template=None):
        """Perform depth-first search (DFS) traversal of a node's input tree.

        It recursively tracks the inputs (or output recipients) of the node
        called node_name, and appends these inputs (or output recipients) to a
        list of text lines (lines) with proper indentation that reflects the
        recursion depth, together with some formatting attributes
        (to attr_segs). The formatting attributes can include command
        shortcuts, for example.

        Args:
          lines: Text lines to append to, as a list of str. Mutated in place.
          attr_segs: (dict) Attribute segments dictionary to append to,
            keyed by line index. Mutated in place.
          node_name: Name of the node, as a str. This arg is updated during
            the recursion.
          tracker: A callable that takes one str as the node name input and
            returns a list of str as the inputs/outputs.
            This makes this function general enough to be used with both
            node-input and node-output tracking.
          max_depth: Maximum recursion depth, as an int.
          depth: Current recursion depth. This arg is updated during the
            recursion.
          unfinished: A stack of unfinished recursion depths, as a list of
            int. A depth on this stack means that level of the tree still has
            siblings left to visit, so the vertical connector line must be
            drawn through it. Mutated in place.
          include_control: Whether control dependencies are to be included as
            inputs (and marked as such).
          show_op_type: Whether op type of the input nodes are to be displayed
            alongside the nodes' names.
          command_template: (str) Template for command shortcut of the node
            names. If set, each printed node name becomes a clickable
            MenuItem built from this template.
        """

        # Make a shallow copy of the list because it may be extended later.
        all_inputs = copy.copy(tracker(node_name, is_control=False))
        # Parallel list of flags marking which entries are control edges.
        is_ctrl = [False] * len(all_inputs)
        if include_control:
            # Sort control inputs or recipients in alphabetical order of the
            # node names.
            ctrl_inputs = sorted(tracker(node_name, is_control=True))
            all_inputs.extend(ctrl_inputs)
            is_ctrl.extend([True] * len(ctrl_inputs))

        if not all_inputs:
            # Leaf: nothing to print except a placeholder at the tree root.
            if depth == 1:
                lines.append("  [None]")

            return

        # Mark this depth as having unvisited siblings so deeper levels draw
        # the vertical connector through it.
        unfinished.append(depth)

        # Create depth-dependent hanging indent for the line.
        hang = ""
        for k in xrange(depth):
            if k < depth - 1:
                # Ancestor levels: draw a connector if that level still has
                # pending siblings, otherwise blank padding.
                if k + 1 in unfinished:
                    hang += HANG_UNFINISHED
                else:
                    hang += HANG_FINISHED
            else:
                # Innermost level: the branch marker for this node itself.
                hang += HANG_SUFFIX

        if all_inputs and depth > max_depth:
            # Depth limit reached: elide the subtree and unwind the stack.
            lines.append(hang + ELLIPSIS)
            unfinished.pop()
            return

        hang += DEPTH_TEMPLATE % depth

        for i in xrange(len(all_inputs)):
            inp = all_inputs[i]
            if is_ctrl[i]:
                ctrl_str = CTRL_LABEL
            else:
                ctrl_str = ""

            op_type_str = ""
            if show_op_type:
                op_type_str = OP_TYPE_TEMPLATE % self._debug_dump.node_op_type(
                    inp)

            if i == len(all_inputs) - 1:
                # Last sibling: pop before recursing so descendants stop
                # drawing the connector through this depth.
                unfinished.pop()

            line = hang + ctrl_str + op_type_str + inp
            lines.append(line)
            if command_template:
                # Make the node-name portion of the line clickable.
                attr_segs[len(lines) - 1] = [
                    (len(line) - len(inp), len(line),
                     debugger_cli_common.MenuItem(None,
                                                  command_template % inp))
                ]

            # Recursive call.
            # The input's/output's name can be a tensor name, in the case of node
            # with >1 output slots.
            inp_node_name, _ = debug_data.parse_node_or_tensor_name(inp)
            self._dfs_from_node(lines,
                                attr_segs,
                                inp_node_name,
                                tracker,
                                max_depth,
                                depth + 1,
                                unfinished,
                                include_control=include_control,
                                show_op_type=show_op_type,
                                command_template=command_template)
Example #13
0
    def _get_list_profile_lines(self, device_name, device_index, device_count,
                                profile_datum_list, sort_by, sort_reverse):
        """Render the list_profile table for one device as `RichTextLines`.

        Args:
          device_name: (string) Device name.
          device_index: (int) Zero-based index of this device.
          device_count: (int) Total number of devices.
          profile_datum_list: List of `ProfileDatum` objects to display.
          sort_by: (string) Identifier of column to sort. Sort identifier
              must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_EXEC_TIME,
              SORT_OPS_BY_MEMORY or SORT_OPS_BY_LINE.
          sort_reverse: (bool) Whether to sort in descending instead of
              default (ascending) order.

        Returns:
          `RichTextLines` object containing a table that displays profiling
          information for each op, with clickable column headers that
          re-sort the table.
        """
        table = ProfileDataTableView(profile_datum_list)

        # Totals are computed up front so columns can be sized to fit them.
        total_op_time = sum(datum.op_time for datum in profile_datum_list)
        total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                              for datum in profile_datum_list)
        device_total_row = [
            "Device Total",
            cli_shared.time_to_readable_str(total_op_time),
            cli_shared.time_to_readable_str(total_exec_time)
        ]

        # Size each column to fit its header, the totals row and every data
        # cell, plus a 2-space margin between columns.
        widths = [len(name) for name in table.column_names()]
        for idx, total_cell in enumerate(device_total_row):
            widths[idx] = max(widths[idx], len(total_cell))
        for idx in range(len(widths)):
            cell_lens = [
                len(str(table.value(r, idx)))
                for r in range(table.row_count())
            ]
            widths[idx] = max([widths[idx]] + cell_lens) + 2

        # Banner identifying the device.
        output = debugger_cli_common.RichTextLines(["-" * 80])
        device_row = "Device %d of %d: %s" % (device_index + 1, device_count,
                                              device_name)
        output.extend(debugger_cli_common.RichTextLines([device_row, ""]))

        # Header row. Each column name is a clickable menu item that re-sorts
        # by that column; clicking the active ascending sort column toggles
        # to descending order.
        base_command = "list_profile"
        header_attrs = {0: []}
        header = ""
        for idx in range(table.column_count()):
            name = table.column_names()[idx]
            sort_id = table.column_sort_id(idx)
            command = "%s -s %s" % (base_command, sort_id)
            if sort_by == sort_id and not sort_reverse:
                command += " -r"
            start = len(header)
            header += ("{:<%d}" % widths[idx]).format(name)
            header_attrs[0].append(
                (start, start + len(name),
                 [debugger_cli_common.MenuItem(None, command), "bold"]))

        output.extend(
            debugger_cli_common.RichTextLines([header],
                                              font_attr_segs=header_attrs))

        # One data row per op.
        for r in range(table.row_count()):
            cells = [("{:<%d}" % widths[c]).format(table.value(r, c))
                     for c in range(table.column_count())]
            output.extend(debugger_cli_common.RichTextLines(["".join(cells)]))

        # Totals row, preceded by a blank line.
        total_cells = [("{:<%d}" % widths[c]).format(device_total_row[c])
                       for c in range(len(device_total_row))]
        output.extend(debugger_cli_common.RichTextLines(""))
        output.extend(debugger_cli_common.RichTextLines("".join(total_cells)))
        return output