Example #1
0
  def __init__(self, profile_datum_list, time_unit=cli_shared.TIME_UNIT_US):
    """Builds the table view and pre-renders its time columns.

    Args:
      profile_datum_list: List of `ProfileDatum` objects.
      time_unit: Unit forced on the "Op Time" and "Exec Time" columns; must be
        in cli_shared.TIME_UNITS.
    """
    self._profile_datum_list = profile_datum_list

    def _render(micros):
      # Format a microsecond count in the requested unit.
      return cli_shared.time_to_readable_str(micros, force_time_unit=time_unit)

    # Start times are kept as-is; their header labels them as microseconds.
    self.formatted_start_time = [d.start_time for d in profile_datum_list]
    self.formatted_op_time = [_render(d.op_time) for d in profile_datum_list]
    self.formatted_exec_time = [
        _render(d.node_exec_stats.all_end_rel_micros)
        for d in profile_datum_list]

    # Parallel lists: entry i of each describes the same table column.
    self._column_names = [
        "Node",
        "Op Type",
        "Start Time (us)",
        "Op Time (%s)" % time_unit,
        "Exec Time (%s)" % time_unit,
        "Filename:Lineno(function)",
    ]
    self._column_sort_ids = [
        SORT_OPS_BY_OP_NAME,
        SORT_OPS_BY_OP_TYPE,
        SORT_OPS_BY_START_TIME,
        SORT_OPS_BY_OP_TIME,
        SORT_OPS_BY_EXEC_TIME,
        SORT_OPS_BY_LINE,
    ]
  def __init__(self, profile_datum_list, time_unit=cli_shared.TIME_UNIT_US):
    """Initializes the table view and pre-formats the time columns.

    Args:
      profile_datum_list: List of `ProfileDatum` objects.
      time_unit: Unit forced on the "Op Time" and "Exec Time" columns; must be
        in cli_shared.TIME_UNITS.
    """
    self._profile_datum_list = profile_datum_list
    data = profile_datum_list
    to_str = cli_shared.time_to_readable_str

    self.formatted_start_time = [item.start_time for item in data]
    self.formatted_op_time = [
        to_str(item.op_time, force_time_unit=time_unit) for item in data]
    self.formatted_exec_time = [
        to_str(item.node_exec_stats.all_end_rel_micros,
               force_time_unit=time_unit)
        for item in data]

    # Column headers and their sort IDs, in matching order.
    op_time_head = "Op Time (%s)" % time_unit
    exec_time_head = "Exec Time (%s)" % time_unit
    self._column_names = ["Node", "Op Type", "Start Time (us)",
                          op_time_head, exec_time_head,
                          "Filename:Lineno(function)"]
    self._column_sort_ids = [
        SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE, SORT_OPS_BY_START_TIME,
        SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE]
  def __init__(self, profile_datum_list):
    """Builds the table view and pre-renders its op/exec time strings.

    Args:
      profile_datum_list: List of `ProfileDatum` objects.
    """
    self._profile_datum_list = profile_datum_list
    readable = cli_shared.time_to_readable_str
    self.formatted_op_time = [readable(d.op_time) for d in profile_datum_list]
    self.formatted_exec_time = [
        readable(d.node_exec_stats.all_end_rel_micros)
        for d in profile_datum_list]
    # Sort ID for each column, in column order.
    self._column_sort_ids = [
        SORT_OPS_BY_OP_NAME,
        SORT_OPS_BY_OP_TIME,
        SORT_OPS_BY_EXEC_TIME,
        SORT_OPS_BY_LINE,
    ]
  def __init__(self, profile_datum_list):
    """Initializes the table view with human-readable time strings.

    Args:
      profile_datum_list: List of `ProfileDatum` objects.
    """
    self._profile_datum_list = profile_datum_list
    to_str = cli_shared.time_to_readable_str
    self.formatted_op_time = [to_str(item.op_time)
                              for item in profile_datum_list]
    self.formatted_exec_time = [
        to_str(item.node_exec_stats.all_end_rel_micros)
        for item in profile_datum_list]
    # Sort ID for each column, in column order.
    self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TIME,
                             SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE]
Example #5
0
  def testForceTimeUnit(self):
    # A forced unit is used regardless of the magnitude of the value.
    self.assertEqual("40s",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_S))
    self.assertEqual("40000ms",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_MS))
    self.assertEqual("40000000us",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_US))
    # Small values in a large unit fall back to scientific notation.
    self.assertEqual("4e-05s",
                     cli_shared.time_to_readable_str(
                         40, force_time_unit=cli_shared.TIME_UNIT_S))
    # Zero is rendered as a bare "0" even when a unit is forced.
    self.assertEqual("0",
                     cli_shared.time_to_readable_str(
                         0, force_time_unit=cli_shared.TIME_UNIT_S))

    # Unknown units are rejected. `assertRaisesRegex` replaces the deprecated
    # `assertRaisesRegexp` alias, which Python 3.12 removed from unittest
    # (assumes this class derives from unittest.TestCase — TODO confirm).
    with self.assertRaisesRegex(ValueError, r"Invalid time unit: ks"):
      cli_shared.time_to_readable_str(100, force_time_unit="ks")
  def testForceTimeUnit(self):
    # A forced unit is used regardless of the magnitude of the value.
    self.assertEqual("40s",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_S))
    self.assertEqual("40000ms",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_MS))
    self.assertEqual("40000000us",
                     cli_shared.time_to_readable_str(
                         40e6, force_time_unit=cli_shared.TIME_UNIT_US))
    # Small values in a large unit fall back to scientific notation.
    self.assertEqual("4e-05s",
                     cli_shared.time_to_readable_str(
                         40, force_time_unit=cli_shared.TIME_UNIT_S))
    # Zero is rendered as a bare "0" even when a unit is forced.
    self.assertEqual("0",
                     cli_shared.time_to_readable_str(
                         0, force_time_unit=cli_shared.TIME_UNIT_S))

    # Unknown units are rejected. `assertRaisesRegex` replaces the deprecated
    # `assertRaisesRegexp` alias, which Python 3.12 removed from unittest
    # (assumes this class derives from unittest.TestCase — TODO confirm).
    with self.assertRaisesRegex(ValueError, r"Invalid time unit: ks"):
      cli_shared.time_to_readable_str(100, force_time_unit="ks")
Example #7
0
 def testMillisecondTime(self):
     # 40,000 microseconds is rendered in milliseconds.
     readable = cli_shared.time_to_readable_str(40e3)
     self.assertEqual("40ms", readable)
Example #8
0
 def testMicrosecondsTime(self):
     # 40 microseconds keeps the "us" suffix.
     readable = cli_shared.time_to_readable_str(40)
     self.assertEqual("40us", readable)
Example #9
0
 def testNoneTimeWorks(self):
     # A missing (None) time is rendered as "0" rather than raising.
     readable = cli_shared.time_to_readable_str(None)
     self.assertEqual("0", readable)
Example #10
0
 def testMicrosecondsTime(self):
   # 40 microseconds keeps the "us" suffix.
   expected = "40us"
   self.assertEqual(expected, cli_shared.time_to_readable_str(40))
Example #11
0
 def testNoneTimeWorks(self):
   # A missing (None) time is rendered as "0" rather than raising.
   expected = "0"
   self.assertEqual(expected, cli_shared.time_to_readable_str(None))
  def print_source(self, args, screen_info=None):
    """Print a Python source file with line-level profile information.

    Source lines that have profile data attributed to them are prefixed with a
    normalized cost bar, the line's total cost (formatted with `--time_unit`)
    and a "nodes(executions)" count carrying a menu item that runs `lp`
    restricted to that line.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols. Not used by this command.

    Returns:
      Output text lines as a RichTextLines object. If the file has no profile
      information under the given filters, a short message listing the filters
      is returned instead.
    """
    del screen_info

    parsed = self._arg_parsers["print_source"].parse_args(args)

    # Expand "~" once and use the same path for annotating, loading and the
    # per-line `lp` filter. If only the annotation step expanded it, a path
    # such as "~/model.py" would match profile data but then fail to load.
    source_file_path = os.path.expanduser(parsed.source_file_path)

    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    # Gather profile data from every device that passes the device filter.
    profile_data = []
    data_generator = self._get_profile_data_generator()
    for device_stats in self._run_metadata.step_stats.dev_stats:
      if device_name_regex and not device_name_regex.match(device_stats.device):
        continue
      profile_data.extend(data_generator(device_stats))

    source_annotation = source_utils.annotate_source_against_profile(
        profile_data,
        source_file_path,
        node_name_filter=parsed.node_name_filter,
        op_type_filter=parsed.op_type_filter)
    if not source_annotation:
      return debugger_cli_common.RichTextLines(
          ["The source file %s does not contain any profile information for "
           "the previous Session run under the following "
           "filters:" % parsed.source_file_path,
           "  --%s: %s" % (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
           "  --%s: %s" % (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
           "  --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)])

    # Compute each annotated line's cost once; the largest one sets the scale
    # of the cost bars.
    line_costs = {}
    max_total_cost = 0
    for lineno in source_annotation:
      line_costs[lineno] = self._get_total_cost(source_annotation[lineno],
                                                parsed.cost_type)
      max_total_cost = max(max_total_cost, line_costs[lineno])

    source_lines, line_num_width = source_utils.load_source(source_file_path)

    cost_bar_max_length = 10
    total_cost_head = parsed.cost_type
    column_widths = {
        "cost_bar": cost_bar_max_length + 3,
        "total_cost": len(total_cost_head) + 3,
        "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
        "line_number": line_num_width,
    }

    head = RL(
        " " * column_widths["cost_bar"] +
        total_cost_head +
        " " * (column_widths["total_cost"] - len(total_cost_head)) +
        self._NUM_NODES_HEAD +
        " " * (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
        font_attr=self._LINE_COST_ATTR)
    head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
    sub_head = RL(
        " " * (column_widths["cost_bar"] +
               column_widths["total_cost"]) +
        self._NUM_EXECS_SUB_HEAD +
        " " * (column_widths["num_nodes_execs"] -
               len(self._NUM_EXECS_SUB_HEAD)) +
        " " * column_widths["line_number"],
        font_attr=self._LINE_COST_ATTR)
    sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
    lines = [head, sub_head]

    # The `lp` file filter depends only on the file, so build it once.
    file_path_filter = re.escape(source_file_path) + "$"

    output_annotations = {}
    for lineno, line in enumerate(source_lines, start=1):
      if lineno in source_annotation:
        annotation = source_annotation[lineno]
        line_cost = line_costs[lineno]
        cost_bar = self._render_normalized_cost_bar(
            line_cost, max_total_cost, cost_bar_max_length)
        annotated_line = cost_bar
        annotated_line += " " * (column_widths["cost_bar"] - len(cost_bar))

        total_cost = RL(cli_shared.time_to_readable_str(
            line_cost, force_time_unit=parsed.time_unit),
                        font_attr=self._LINE_COST_ATTR)
        total_cost += " " * (column_widths["total_cost"] - len(total_cost))
        annotated_line += total_cost

        # Menu command that lists the profile of just this line, carrying
        # over the active filters.
        command = "lp --file_path_filter %s --min_lineno %d --max_lineno %d" % (
            file_path_filter, lineno, lineno + 1)
        if parsed.device_name_filter:
          command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                   parsed.device_name_filter)
        if parsed.node_name_filter:
          command += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                   parsed.node_name_filter)
        if parsed.op_type_filter:
          command += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                   parsed.op_type_filter)
        menu_item = debugger_cli_common.MenuItem(None, command)
        num_nodes_execs = RL("%d(%d)" % (annotation.node_count,
                                         annotation.node_exec_count),
                             font_attr=[self._LINE_COST_ATTR, menu_item])
        num_nodes_execs += " " * (
            column_widths["num_nodes_execs"] - len(num_nodes_execs))
        annotated_line += num_nodes_execs
      else:
        # Unannotated lines get blank space where the cost columns would be.
        annotated_line = RL(
            " " * sum(column_widths[col_name] for col_name in column_widths
                      if col_name != "line_number"))

      line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
      line_num_column += " " * (
          column_widths["line_number"] - len(line_num_column))
      annotated_line += line_num_column
      annotated_line += line
      lines.append(annotated_line)

      if parsed.init_line == lineno:
        output_annotations[
            debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        lines, annotations=output_annotations)
  def _get_list_profile_lines(
      self, device_name, device_index, device_count,
      profile_datum_list, sort_by, sort_reverse, time_unit,
      device_name_filter=None, node_name_filter=None, op_type_filter=None,
      screen_cols=80):
    """Get `RichTextLines` object for list_profile command for a given device.

    The output is a separator line, a "Device i of n" caption, a header row
    whose cells carry `list_profile -s <sort id>` menu items, one row per
    profile datum and a trailing "Device Total" row.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index (0-based; displayed 1-based).
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Sort ID of the column the data is sorted by, i.e. one
          of the SORT_OPS_BY_* constants used as column sort IDs. The header
          of that column links to the reversed sort (`-r`) unless
          `sort_reverse` is already set.
      sort_reverse: (bool) Whether to sort in descending instead of default
          (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.
      device_name_filter: Regular expression to filter by device name; passed
        through to `ProfileDataTableView.value()`.
      node_name_filter: Regular expression to filter by node name; passed
        through to `ProfileDataTableView.value()`.
      op_type_filter: Regular expression to filter by op type; passed through
        to `ProfileDataTableView.value()`.
      screen_cols: (int) Number of columns available on the screen (i.e.,
        available screen width). Sets the width of the separator line.

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op.
    """
    profile_data = ProfileDataTableView(profile_datum_list, time_unit=time_unit)

    # Calculate total time early to calculate column widths.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    # Place each total under the column with the matching sort ID rather than
    # at a fixed index, so the totals stay under "Op Time" and "Exec Time" no
    # matter how many columns (e.g. "Op Type", "Start Time") precede them.
    totals_by_sort_id = {
        SORT_OPS_BY_OP_TIME: cli_shared.time_to_readable_str(
            total_op_time, force_time_unit=time_unit),
        SORT_OPS_BY_EXEC_TIME: cli_shared.time_to_readable_str(
            total_exec_time, force_time_unit=time_unit),
    }
    device_total_row = ["Device Total"] + [
        totals_by_sort_id.get(profile_data.column_sort_id(col), "")
        for col in range(1, profile_data.column_count())]
    # Drop trailing empty cells so the row is not padded past its last value.
    while not device_total_row[-1]:
      device_total_row.pop()

    # Calculate column widths.
    column_widths = [
        len(column_name) for column_name in profile_data.column_names()]
    for col in range(len(device_total_row)):
      column_widths[col] = max(column_widths[col], len(device_total_row[col]))
    for col in range(len(column_widths)):
      for row in range(profile_data.row_count()):
        column_widths[col] = max(
            column_widths[col], len(profile_data.value(
                row,
                col,
                device_name_filter=device_name_filter,
                node_name_filter=node_name_filter,
                op_type_filter=op_type_filter)))
      column_widths[col] += 2  # add margin between columns

    # Add device name.
    output = [RL("-" * screen_cols)]
    device_row = "Device %d of %d: %s" % (
        device_index + 1, device_count, device_name)
    output.append(RL(device_row))
    output.append(RL())

    # Add headers. Only the column name itself (not its padding) carries the
    # sort menu item and bold attribute.
    base_command = "list_profile"
    row = RL()
    for col in range(profile_data.column_count()):
      column_name = profile_data.column_names()[col]
      sort_id = profile_data.column_sort_id(col)
      command = "%s -s %s" % (base_command, sort_id)
      if sort_by == sort_id and not sort_reverse:
        command += " -r"
      head_menu_item = debugger_cli_common.MenuItem(None, command)
      row += RL(column_name, font_attr=[head_menu_item, "bold"])
      row += RL(" " * (column_widths[col] - len(column_name)))

    output.append(row)

    # Add data rows.
    for row in range(profile_data.row_count()):
      new_row = RL()
      for col in range(profile_data.column_count()):
        new_cell = profile_data.value(
            row,
            col,
            device_name_filter=device_name_filter,
            node_name_filter=node_name_filter,
            op_type_filter=op_type_filter)
        new_row += new_cell
        new_row += RL(" " * (column_widths[col] - len(new_cell)))
      output.append(new_row)

    # Add stat totals.
    row_str = "".join(
        ("{:<%d}" % width).format(cell)
        for width, cell in zip(column_widths, device_total_row))
    output.append(RL())
    output.append(RL(row_str))
    return debugger_cli_common.rich_text_lines_from_rich_line_list(output)
  def _get_list_profile_lines(
      self, device_name, device_index, device_count,
      profile_datum_list, sort_by, sort_reverse, time_unit):
    """Get `RichTextLines` object for list_profile command for a given device.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index (0-based; displayed 1-based).
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Sort ID of the column the data is sorted by, i.e. one
          of the SORT_OPS_BY_* constants used as column sort IDs. The header
          of that column links to the reversed sort (`-r`) unless
          `sort_reverse` is already set.
      sort_reverse: (bool) Whether to sort in descending instead of default
          (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op, followed by a "Device Total" row.
    """
    profile_data = ProfileDataTableView(profile_datum_list, time_unit=time_unit)

    # Calculate total time early to calculate column widths.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    # Place each total under the column with the matching sort ID rather than
    # at a fixed index, so the totals stay under "Op Time" and "Exec Time" no
    # matter how many columns (e.g. "Op Type", "Start Time") precede them.
    totals_by_sort_id = {
        SORT_OPS_BY_OP_TIME: cli_shared.time_to_readable_str(
            total_op_time, force_time_unit=time_unit),
        SORT_OPS_BY_EXEC_TIME: cli_shared.time_to_readable_str(
            total_exec_time, force_time_unit=time_unit),
    }
    device_total_row = ["Device Total"] + [
        totals_by_sort_id.get(profile_data.column_sort_id(col), "")
        for col in range(1, profile_data.column_count())]
    # Drop trailing empty cells so the row is not padded past its last value.
    while not device_total_row[-1]:
      device_total_row.pop()

    # Calculate column widths.
    column_widths = [
        len(column_name) for column_name in profile_data.column_names()]
    for col in range(len(device_total_row)):
      column_widths[col] = max(column_widths[col], len(device_total_row[col]))
    for col in range(len(column_widths)):
      for row in range(profile_data.row_count()):
        column_widths[col] = max(
            column_widths[col], len(str(profile_data.value(row, col))))
      column_widths[col] += 2  # add margin between columns

    # Add device name.
    output = debugger_cli_common.RichTextLines(["-"*80])
    device_row = "Device %d of %d: %s" % (
        device_index + 1, device_count, device_name)
    output.extend(debugger_cli_common.RichTextLines([device_row, ""]))

    # Add headers. Only the column name itself (not its padding) carries the
    # sort menu item and bold attribute.
    base_command = "list_profile"
    attr_segs = {0: []}
    row = ""
    for col in range(profile_data.column_count()):
      column_name = profile_data.column_names()[col]
      sort_id = profile_data.column_sort_id(col)
      command = "%s -s %s" % (base_command, sort_id)
      if sort_by == sort_id and not sort_reverse:
        command += " -r"
      curr_row = ("{:<%d}" % column_widths[col]).format(column_name)
      prev_len = len(row)
      row += curr_row
      attr_segs[0].append(
          (prev_len, prev_len + len(column_name),
           [debugger_cli_common.MenuItem(None, command), "bold"]))

    output.extend(
        debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs))

    # Add data rows.
    for row in range(profile_data.row_count()):
      row_str = ""
      for col in range(profile_data.column_count()):
        row_str += ("{:<%d}" % column_widths[col]).format(
            profile_data.value(row, col))
      output.extend(debugger_cli_common.RichTextLines([row_str]))

    # Add stat totals.
    row_str = "".join(
        ("{:<%d}" % width).format(cell)
        for width, cell in zip(column_widths, device_total_row))
    output.extend(debugger_cli_common.RichTextLines(""))
    output.extend(debugger_cli_common.RichTextLines(row_str))
    return output
Example #15
0
    def _get_list_profile_lines(self, device_name, device_index, device_count,
                                profile_datum_list, sort_by, sort_reverse):
        """Get `RichTextLines` object for list_profile command for a given device.

        The table is a separator line, a "Device i of n" caption, a header row
        whose cells carry `list_profile -s <sort id>` menu items, one row per
        profile datum and a trailing "Device Total" row.

        Args:
          device_name: (string) Device name.
          device_index: (int) Device index (0-based; displayed 1-based).
          device_count: (int) Number of devices.
          profile_datum_list: List of `ProfileDatum` objects.
          sort_by: (string) Sort ID of the column the data is sorted by. The
              header of that column links to the reversed sort (`-r`) unless
              `sort_reverse` is already set.
          sort_reverse: (bool) Whether to sort in descending instead of default
              (ascending) order.

        Returns:
          `RichTextLines` object containing a table that displays profiling
          information for each op.
        """
        table = ProfileDataTableView(profile_datum_list)
        names = table.column_names()
        num_rows = table.row_count()
        num_cols = table.column_count()

        def pad(cell, width):
            # Left-align `cell` in a field of `width` characters.
            return ("{:<%d}" % width).format(cell)

        # Totals are computed first because they may widen their columns.
        op_time_sum = sum(d.op_time for d in profile_datum_list)
        exec_time_sum = sum(d.node_exec_stats.all_end_rel_micros
                            for d in profile_datum_list)
        totals = [
            "Device Total",
            cli_shared.time_to_readable_str(op_time_sum),
            cli_shared.time_to_readable_str(exec_time_sum),
        ]

        # Each column fits its header, its total (if any) and every value,
        # plus a two-space gutter.
        widths = [len(name) for name in names]
        for col, total in enumerate(totals):
            widths[col] = max(widths[col], len(total))
        widths = [
            max([width] + [len(str(table.value(row, col)))
                           for row in range(num_rows)]) + 2
            for col, width in enumerate(widths)
        ]

        output = debugger_cli_common.RichTextLines(["-" * 80])
        caption = "Device %d of %d: %s" % (device_index + 1, device_count,
                                           device_name)
        output.extend(debugger_cli_common.RichTextLines([caption, ""]))

        # Header row: only the column name itself (not its padding) carries
        # the sort menu item and bold attribute.
        base_command = "list_profile"
        header = ""
        header_segs = []
        for col in range(num_cols):
            name = names[col]
            sort_id = table.column_sort_id(col)
            command = "%s -s %s" % (base_command, sort_id)
            if sort_by == sort_id and not sort_reverse:
                command += " -r"
            start = len(header)
            header += pad(name, widths[col])
            header_segs.append(
                (start, start + len(name),
                 [debugger_cli_common.MenuItem(None, command), "bold"]))
        output.extend(debugger_cli_common.RichTextLines(
            [header], font_attr_segs={0: header_segs}))

        for row in range(num_rows):
            cells = [pad(table.value(row, col), widths[col])
                     for col in range(num_cols)]
            output.extend(debugger_cli_common.RichTextLines(["".join(cells)]))

        totals_line = "".join(
            pad(total, widths[col]) for col, total in enumerate(totals))
        output.extend(debugger_cli_common.RichTextLines(""))
        output.extend(debugger_cli_common.RichTextLines(totals_line))
        return output
Example #16
0
 def testSecondTime(self):
     # 40,000,000 microseconds is rendered in seconds.
     readable = cli_shared.time_to_readable_str(40e6)
     self.assertEqual("40s", readable)
Example #17
0
  def _get_list_profile_lines(
      self, device_name, device_index, device_count,
      profile_datum_list, sort_by, sort_reverse, time_unit,
      device_name_filter=None, node_name_filter=None, op_type_filter=None,
      screen_cols=80):
    """Get `RichTextLines` object for list_profile command for a given device.

    The output is a separator line, a "Device i of n" caption, a header row
    whose cells carry `list_profile -s <sort id>` menu items, one row per
    profile datum and a trailing "Device Total" row.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index (0-based; displayed 1-based).
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Sort ID of the column the data is sorted by, i.e. one
          of the SORT_OPS_BY_* constants used as column sort IDs. The header
          of that column links to the reversed sort (`-r`) unless
          `sort_reverse` is already set.
      sort_reverse: (bool) Whether to sort in descending instead of default
          (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.
      device_name_filter: Regular expression to filter by device name; passed
        through to `ProfileDataTableView.value()`.
      node_name_filter: Regular expression to filter by node name; passed
        through to `ProfileDataTableView.value()`.
      op_type_filter: Regular expression to filter by op type; passed through
        to `ProfileDataTableView.value()`.
      screen_cols: (int) Number of columns available on the screen (i.e.,
        available screen width). Sets the width of the separator line.

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op.
    """
    profile_data = ProfileDataTableView(profile_datum_list, time_unit=time_unit)

    # Calculate total time early to calculate column widths.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    # Place each total under the column with the matching sort ID rather than
    # at a fixed index, so the totals stay under "Op Time" and "Exec Time" no
    # matter how many columns (e.g. "Op Type", "Start Time") precede them.
    totals_by_sort_id = {
        SORT_OPS_BY_OP_TIME: cli_shared.time_to_readable_str(
            total_op_time, force_time_unit=time_unit),
        SORT_OPS_BY_EXEC_TIME: cli_shared.time_to_readable_str(
            total_exec_time, force_time_unit=time_unit),
    }
    device_total_row = ["Device Total"] + [
        totals_by_sort_id.get(profile_data.column_sort_id(col), "")
        for col in range(1, profile_data.column_count())]
    # Drop trailing empty cells so the row is not padded past its last value.
    while not device_total_row[-1]:
      device_total_row.pop()

    # Calculate column widths.
    column_widths = [
        len(column_name) for column_name in profile_data.column_names()]
    for col in range(len(device_total_row)):
      column_widths[col] = max(column_widths[col], len(device_total_row[col]))
    for col in range(len(column_widths)):
      for row in range(profile_data.row_count()):
        column_widths[col] = max(
            column_widths[col], len(profile_data.value(
                row,
                col,
                device_name_filter=device_name_filter,
                node_name_filter=node_name_filter,
                op_type_filter=op_type_filter)))
      column_widths[col] += 2  # add margin between columns

    # Add device name.
    output = [RL("-" * screen_cols)]
    device_row = "Device %d of %d: %s" % (
        device_index + 1, device_count, device_name)
    output.append(RL(device_row))
    output.append(RL())

    # Add headers. Only the column name itself (not its padding) carries the
    # sort menu item and bold attribute.
    base_command = "list_profile"
    row = RL()
    for col in range(profile_data.column_count()):
      column_name = profile_data.column_names()[col]
      sort_id = profile_data.column_sort_id(col)
      command = "%s -s %s" % (base_command, sort_id)
      if sort_by == sort_id and not sort_reverse:
        command += " -r"
      head_menu_item = debugger_cli_common.MenuItem(None, command)
      row += RL(column_name, font_attr=[head_menu_item, "bold"])
      row += RL(" " * (column_widths[col] - len(column_name)))

    output.append(row)

    # Add data rows.
    for row in range(profile_data.row_count()):
      new_row = RL()
      for col in range(profile_data.column_count()):
        new_cell = profile_data.value(
            row,
            col,
            device_name_filter=device_name_filter,
            node_name_filter=node_name_filter,
            op_type_filter=op_type_filter)
        new_row += new_cell
        new_row += RL(" " * (column_widths[col] - len(new_cell)))
      output.append(new_row)

    # Add stat totals.
    row_str = "".join(
        ("{:<%d}" % width).format(cell)
        for width, cell in zip(column_widths, device_total_row))
    output.append(RL())
    output.append(RL(row_str))
    return debugger_cli_common.rich_text_lines_from_rich_line_list(output)
Example #18
0
 def testMillisecondTime(self):
   # 40,000 microseconds is rendered in milliseconds.
   expected = "40ms"
   self.assertEqual(expected, cli_shared.time_to_readable_str(40e3))
Example #19
0
  def print_source(self, args, screen_info=None):
    """Print a Python source file with line-level profile information.

    Source lines that have profile data attributed to them are prefixed with a
    normalized cost bar, the line's total cost (formatted with `--time_unit`)
    and a "nodes(executions)" count carrying a menu item that runs `lp`
    restricted to that line.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols. Not used by this command.

    Returns:
      Output text lines as a RichTextLines object. If the file has no profile
      information under the given filters, a short message listing the filters
      is returned instead.
    """
    del screen_info

    parsed = self._arg_parsers["print_source"].parse_args(args)

    # Expand "~" once and use the same path for annotating, loading and the
    # per-line `lp` filter. If only the annotation step expanded it, a path
    # such as "~/model.py" would match profile data but then fail to load.
    source_file_path = os.path.expanduser(parsed.source_file_path)

    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    # Gather profile data from every device that passes the device filter.
    profile_data = []
    data_generator = self._get_profile_data_generator()
    for device_stats in self._run_metadata.step_stats.dev_stats:
      if device_name_regex and not device_name_regex.match(device_stats.device):
        continue
      profile_data.extend(data_generator(device_stats))

    source_annotation = source_utils.annotate_source_against_profile(
        profile_data,
        source_file_path,
        node_name_filter=parsed.node_name_filter,
        op_type_filter=parsed.op_type_filter)
    if not source_annotation:
      return debugger_cli_common.RichTextLines(
          ["The source file %s does not contain any profile information for "
           "the previous Session run under the following "
           "filters:" % parsed.source_file_path,
           "  --%s: %s" % (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
           "  --%s: %s" % (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
           "  --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)])

    # Compute each annotated line's cost once; the largest one sets the scale
    # of the cost bars.
    line_costs = {}
    max_total_cost = 0
    for lineno in source_annotation:
      line_costs[lineno] = self._get_total_cost(source_annotation[lineno],
                                                parsed.cost_type)
      max_total_cost = max(max_total_cost, line_costs[lineno])

    source_lines, line_num_width = source_utils.load_source(source_file_path)

    cost_bar_max_length = 10
    total_cost_head = parsed.cost_type
    column_widths = {
        "cost_bar": cost_bar_max_length + 3,
        "total_cost": len(total_cost_head) + 3,
        "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
        "line_number": line_num_width,
    }

    head = RL(
        " " * column_widths["cost_bar"] +
        total_cost_head +
        " " * (column_widths["total_cost"] - len(total_cost_head)) +
        self._NUM_NODES_HEAD +
        " " * (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
        font_attr=self._LINE_COST_ATTR)
    head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
    sub_head = RL(
        " " * (column_widths["cost_bar"] +
               column_widths["total_cost"]) +
        self._NUM_EXECS_SUB_HEAD +
        " " * (column_widths["num_nodes_execs"] -
               len(self._NUM_EXECS_SUB_HEAD)) +
        " " * column_widths["line_number"],
        font_attr=self._LINE_COST_ATTR)
    sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
    lines = [head, sub_head]

    # The `lp` file filter depends only on the file, so build it once.
    file_path_filter = re.escape(source_file_path) + "$"

    output_annotations = {}
    for lineno, line in enumerate(source_lines, start=1):
      if lineno in source_annotation:
        annotation = source_annotation[lineno]
        line_cost = line_costs[lineno]
        cost_bar = self._render_normalized_cost_bar(
            line_cost, max_total_cost, cost_bar_max_length)
        annotated_line = cost_bar
        annotated_line += " " * (column_widths["cost_bar"] - len(cost_bar))

        total_cost = RL(cli_shared.time_to_readable_str(
            line_cost, force_time_unit=parsed.time_unit),
                        font_attr=self._LINE_COST_ATTR)
        total_cost += " " * (column_widths["total_cost"] - len(total_cost))
        annotated_line += total_cost

        # Menu command that lists the profile of just this line, carrying
        # over the active filters.
        command = "lp --file_path_filter %s --min_lineno %d --max_lineno %d" % (
            file_path_filter, lineno, lineno + 1)
        if parsed.device_name_filter:
          command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                   parsed.device_name_filter)
        if parsed.node_name_filter:
          command += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                   parsed.node_name_filter)
        if parsed.op_type_filter:
          command += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                   parsed.op_type_filter)
        menu_item = debugger_cli_common.MenuItem(None, command)
        num_nodes_execs = RL("%d(%d)" % (annotation.node_count,
                                         annotation.node_exec_count),
                             font_attr=[self._LINE_COST_ATTR, menu_item])
        num_nodes_execs += " " * (
            column_widths["num_nodes_execs"] - len(num_nodes_execs))
        annotated_line += num_nodes_execs
      else:
        # Unannotated lines get blank space where the cost columns would be.
        annotated_line = RL(
            " " * sum(column_widths[col_name] for col_name in column_widths
                      if col_name != "line_number"))

      line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
      line_num_column += " " * (
          column_widths["line_number"] - len(line_num_column))
      annotated_line += line_num_column
      annotated_line += line
      lines.append(annotated_line)

      if parsed.init_line == lineno:
        output_annotations[
            debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        lines, annotations=output_annotations)
Example #20
0
 def testSecondTime(self):
   # 40,000,000 microseconds is rendered in seconds.
   expected = "40s"
   self.assertEqual(expected, cli_shared.time_to_readable_str(40e6))