Beispiel #1
0
    def _register_this_run_info(self, curses_cli):
        curses_cli.register_command_handler(
            "run",
            self._run_handler,
            self._argparsers["run"].format_help(),
            prefix_aliases=["r"])
        # TODO(Pariksheet): Python invoke_stepper implementation not support now.
        #        curses_cli.register_command_handler(
        #            "invoke_stepper",
        #            self._on_run_start_step_handler,
        #            self._argparsers["invoke_stepper"].format_help(),
        #            prefix_aliases=["s"])
        curses_cli.register_command_handler(
            "run_info",
            self._run_info_handler,
            self._argparsers["run_info"].format_help(),
            prefix_aliases=["ri"])
        curses_cli.register_command_handler(
            "print_feed",
            self._print_feed_handler,
            self._argparsers["print_feed"].format_help(),
            prefix_aliases=["pf"])

        if self._tensor_filters:
            # Register tab completion for the filter names.
            curses_cli.register_tab_comp_context(
                ["run", "r"], list(self._tensor_filters.keys()))
        if self._feed_dict:
            # Register tab completion for feed_dict keys.
            feed_keys = [
                common.get_graph_element_name(key)
                for key in self._feed_dict.keys()
            ]
            curses_cli.register_tab_comp_context(["print_feed", "pf"],
                                                 feed_keys)
Beispiel #2
0
    def _print_feed_handler(self, args, screen_info=None):
        """Handle the `print_feed` command: show the value fed for a tensor.

        Args:
          args: Command arguments, parsed with the "print_feed" argparser. The
            parsed `tensor_name` may carry a slicing suffix, which is split off
            via `command_parser.parse_tensor_name_with_slicing`.
          screen_info: Optional screen information, used to derive numpy print
            options.

        Returns:
          An error output from `cli_shared.error` if the current feed_dict is
          empty/None or has no key whose graph-element name equals the
          requested name. Otherwise, the result of `cli_shared.format_tensor`
          for the matching fed value.
        """
        printoptions = cli_shared.numpy_printoptions_from_screen_info(
            screen_info)

        if not self._feed_dict:
            return cli_shared.error(
                "The feed_dict of the current run is None or empty.")

        parsed = self._argparsers["print_feed"].parse_args(args)
        tensor_name, tensor_slicing = (
            command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))

        # Scan the feed keys for the first one whose element name matches;
        # fall through to the `else` (no match) with a None sentinel.
        for key in self._feed_dict:
            matched_name = common.get_graph_element_name(key)
            if matched_name == tensor_name:
                break
        else:
            matched_name = None

        if matched_name is None:
            return cli_shared.error(
                "The feed_dict of the current run does not contain the key %s"
                % tensor_name)

        fed_value = self._feed_dict[key]
        return cli_shared.format_tensor(
            fed_value,
            matched_name + " (feed)",
            printoptions,
            print_all=parsed.print_all,
            tensor_slicing=tensor_slicing,
            highlight_options=cli_shared.parse_ranges_highlight(parsed.ranges),
            include_numeric_summary=parsed.numeric_summary)
def get_run_short_description(run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
    """Get a short description of the run() call.

    Args:
      run_call_count: (int) Run call counter.
      fetches: Fetches (outputs) of the `GraphRuntime.run()` call. If `fetches`
        does not contain ';' it is described as a single output; otherwise the
        outputs are counted via `common.get_flattened_names`. See doc of
        `GraphRuntime.run()` for more details.
      feed_dict: Feeds (inputs) to the `GraphRuntime.run()` call. See doc of
        `GraphRuntime.run()` for more details.
      is_callable_runner: (bool) whether a runner returned by
          GraphRuntime.make_callable is being run.

    Returns:
      (str) A short description of the run() call, including information about
        the output(s) and input(s), e.g. "run #3: 2 outputs; 1 input (x)".
    """
    if is_callable_runner:
        return "runner from make_callable()"

    description = "run #%d: " % run_call_count

    if ';' not in fetches:
        # A single fetch is an output; label it consistently with the
        # multi-output branch below (it was previously mislabelled "input").
        description += "1 output (%s); " % common.get_graph_element_name(
            fetches)
    else:
        # Possibly several outputs: count the flattened fetch names.
        num_fetches = len(common.get_flattened_names(fetches))
        if num_fetches > 1:
            description += "%d outputs; " % num_fetches
        else:
            description += "%d output; " % num_fetches

    if not feed_dict:
        description += "0 inputs"
    elif len(feed_dict) == 1:
        key = next(iter(feed_dict))
        # Use the key's `name` attribute when it has one, unless the key is
        # already a string.
        key_name = (key if isinstance(key, six.string_types)
                    or not hasattr(key, "name") else key.name)
        description += "1 input (%s)" % key_name
    else:
        description += "%d inputs" % len(feed_dict)

    return description
def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
    """Generate formatted intro for run-start UI.

    The intro shows the run call number, its outputs (fetches) and inputs
    (feeds), a list of recommended `run` commands, the registered tensor
    filters (each clickable as `run -f <name>`), and a main menu with "run"
    and "exit" entries. For a callable runner, the outputs/inputs section is
    replaced by a single line noting that a `make_callable()` runner is run.

    Args:
      run_call_count: (int) Run call counter.
      fetches: Fetches of the `GraphRuntime.run()` call. See doc of
        `GraphRuntime.run()` for more details.
      feed_dict: Feeds to the `GraphRuntime.run()` call. See doc of
        `GraphRuntime.run()` for more details.
      tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
        callable.
      is_callable_runner: (bool) whether a runner returned by
          GraphRuntime.make_callable is being run.

    Returns:
      (RichTextLines) Formatted intro message about the `GraphRuntime.run()`
        call, with the main menu stored under
        `debugger_cli_common.MAIN_MENU_KEY` in its annotations.
    """
    fetch_lines = common.get_flattened_names(fetches)

    if not feed_dict:
        feed_dict_lines = [debugger_cli_common.RichLine("  (Empty)")]
    else:
        feed_dict_lines = []
        for feed_key in feed_dict:
            feed_key_name = common.get_graph_element_name(feed_key)
            feed_dict_line = debugger_cli_common.RichLine("  ")
            # The name is a clickable link to `pf '<name>'`. The name is
            # surrounded with quotes because feed_key_name may contain spaces.
            feed_dict_line += debugger_cli_common.RichLine(
                feed_key_name,
                debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
            feed_dict_lines.append(feed_dict_line)
    feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
        feed_dict_lines)

    out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
    if is_callable_runner:
        out.append(
            " Running a runner returned by GraphRuntime.make_callable()")
    else:
        out.append(" GraphRuntime.run() call #%d:" % run_call_count)
        out.append("")
        out.append(" Output:")
        out.extend(
            debugger_cli_common.RichTextLines(
                ["   " + line for line in fetch_lines]))
        out.append("")
        out.append(" Inputs:")
        out.extend(feed_dict_lines)
    out.append(_HORIZONTAL_BAR)
    out.append("")
    out.append(" Select one of the following commands to proceed ---->")

    out.extend(
        _recommend_command("run",
                           "Execute the run() call with debug tensor-watching",
                           create_link=True))
    out.extend(
        _recommend_command(
            "run -n",
            "Execute the run() call without debug tensor-watching",
            create_link=True))
    out.extend(
        _recommend_command(
            "run -t <T>",
            "Execute run() calls (T - 1) times without debugging, then "
            "execute run() once more with debugging and drop back to the CLI"))
    out.extend(
        _recommend_command(
            "run -f <filter_name>",
            "Keep executing run() calls until a dumped tensor passes a given, "
            "registered filter (conditional breakpoint mode)"))

    more_lines = ["    Registered filter(s):"]
    if tensor_filters:
        for filter_name in tensor_filters:
            # Each filter name is a clickable link to `run -f <filter_name>`.
            command_menu_node = debugger_cli_common.MenuItem(
                "", "run -f %s" % filter_name)
            more_lines.append(
                RL("        * ") + RL(filter_name, command_menu_node))
    else:
        more_lines.append("        (None)")

    out.extend(
        debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

    # TODO(Pariksheet): The Python invoke_stepper implementation is not
    # supported yet. Once it is, re-enable:
    #    out.extend(
    #        _recommend_command(
    #            "invoke_stepper",
    #            "Use the node-stepper interface, which allows you to interactively "
    #            "step through nodes involved in the graph run() call and "
    #            "inspect/modify their values", create_link=True))

    out.append("")

    #    out.append_rich_line(RL("For more details, see ") +
    #                         RL("help.", debugger_cli_common.MenuItem("", "help")) +
    #                         ".")
    #    out.append("")

    # Make main menu for the run-start intro.
    menu = debugger_cli_common.Menu()
    menu.append(debugger_cli_common.MenuItem("run", "run"))
    # TODO(Pariksheet): The Python invoke_stepper implementation is not
    # supported yet. Once it is, re-enable:
    #    menu.append(debugger_cli_common.MenuItem(
    #        "invoke_stepper", "invoke_stepper"))
    menu.append(debugger_cli_common.MenuItem("exit", "exit"))
    out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

    return out