예제 #1
0
    def _print_feed_handler(self, args, screen_info=None):
        """Command handler for "print_feed": show the value fed to a tensor.

        Args:
          args: (list of str) Command-line arguments for the "print_feed"
            command. The tensor name may carry a slicing suffix.
          screen_info: Optional screen information, converted into numpy
            print options by cli_shared.

        Returns:
          (RichTextLines) The formatted feed value, or an error message if the
          current run's feed_dict is empty/None or has no entry whose graph
          element name matches the requested tensor.
        """
        printoptions = cli_shared.numpy_printoptions_from_screen_info(
            screen_info)

        feeds = self._feed_dict
        if not feeds:
            return cli_shared.error(
                "The feed_dict of the current run is None or empty.")

        parsed = self._argparsers["print_feed"].parse_args(args)
        target_name, slicing = command_parser.parse_tensor_name_with_slicing(
            parsed.tensor_name)

        # Keys are compared by their graph element name rather than directly,
        # so match on the normalized name; the first matching key wins.
        for key in feeds:
            if common.get_graph_element_name(key) == target_name:
                fed_value = feeds[key]
                break
        else:
            return cli_shared.error(
                "The feed_dict of the current run does not contain the key %s"
                % target_name)

        highlight = cli_shared.parse_ranges_highlight(parsed.ranges)
        return cli_shared.format_tensor(
            fed_value,
            target_name + " (feed)",
            printoptions,
            print_all=parsed.print_all,
            tensor_slicing=slicing,
            highlight_options=highlight,
            include_numeric_summary=parsed.numeric_summary)
예제 #2
0
    def inject_value(self, args, screen_info=None):
        """Inject a value into a given tensor.

        Args:
          args: (list of str) Command-line arguments for the "inject_value"
            command: a tensor name and a Python expression for the value.
          screen_info: Optional dict of screen information. If it contains a
            "cols" entry, that width is used as the numpy line width when
            printing the injected value.

        Returns:
          (RichTextLines) Screen output describing the injected value, or an
          error message if the tensor name does not resolve to exactly one
          tensor, the value expression cannot be evaluated, or the stepper
          rejects the override with a ValueError.
        """
        if screen_info and "cols" in screen_info:
            np_printoptions = {"linewidth": screen_info["cols"]}
        else:
            np_printoptions = {}

        parsed = self.arg_parsers["inject_value"].parse_args(args)

        tensor_names = self._resolve_tensor_names(parsed.tensor_name)
        if not tensor_names:
            return cli_shared.error(self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] %
                                    parsed.tensor_name)
        if len(tensor_names) > 1:
            return cli_shared.error(
                self._MESSAGE_TEMPLATES["MULTIPLE_TENSORS"] %
                parsed.tensor_name)
        tensor_name = tensor_names[0]

        # NOTE(review): eval() runs arbitrary Python typed at the debugger
        # prompt. This is presumably acceptable only for an interactive user
        # who already controls the process -- confirm that no untrusted input
        # can reach this command.
        try:
            tensor_value = eval(parsed.tensor_value_str)  # pylint: disable=eval-used
        except Exception as e:  # pylint: disable=broad-except
            # Any error in the user's expression is reported, not raised.
            return cli_shared.error(
                "Failed to evaluate value expression \"%s\": %s" %
                (parsed.tensor_value_str, e))

        # Only the override itself is guarded: a failure while formatting the
        # output must not be misreported as a failed injection.
        try:
            self._node_stepper.override_tensor(tensor_name, tensor_value)
        except ValueError:
            return debugger_cli_common.RichTextLines([
                "ERROR: Failed to inject value to tensor %s" %
                parsed.tensor_name
            ])

        lines = [
            "Injected value \"%s\"" % parsed.tensor_value_str,
            "  to tensor \"%s\":" % tensor_name, ""
        ]
        lines.extend(
            tensor_format.format_tensor(
                tensor_value,
                tensor_name,
                include_metadata=True,
                np_printoptions=np_printoptions).lines)

        return debugger_cli_common.RichTextLines(lines)
예제 #3
0
    def print_tensor(self, args, screen_info=None):
        """Print the value of a tensor that the stepper has access to.

        Args:
          args: (list of str) Command-line arguments for the "print_tensor"
            command. The tensor name may carry a slicing suffix, and ranges
            to highlight may be given.
          screen_info: Optional dict of screen information; a "cols" entry
            sets the numpy line width used for printing.

        Returns:
          (RichTextLines) The formatted tensor value, or an error message if
          the name does not resolve to exactly one tensor or the stepper
          raises ValueError when fetching its value.
        """
        parsed = self.arg_parsers["print_tensor"].parse_args(args)

        np_printoptions = {}
        if screen_info and "cols" in screen_info:
            np_printoptions["linewidth"] = screen_info["cols"]

        # Range-highlighting options (if any were requested).
        highlight = cli_shared.parse_ranges_highlight(parsed.ranges)

        base_name, slicing = command_parser.parse_tensor_name_with_slicing(
            parsed.tensor_name)

        candidates = self._resolve_tensor_names(base_name)
        if not candidates:
            return cli_shared.error(
                self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] % base_name)
        if len(candidates) > 1:
            return cli_shared.error(
                self._MESSAGE_TEMPLATES["MULTIPLE_TENSORS"] % base_name)
        resolved_name = candidates[0]

        try:
            value = self._node_stepper.get_tensor_value(resolved_name)
        except ValueError as err:
            return debugger_cli_common.RichTextLines([str(err)])

        return cli_shared.format_tensor(
            value,
            resolved_name,
            np_printoptions,
            print_all=parsed.print_all,
            tensor_slicing=slicing,
            highlight_options=highlight)
예제 #4
0
    def cont(self, args, screen_info=None):
        """Continue-to action on the graph.

        Continues the stepper to the given target, then returns a windowed
        view of the sorted transitive closure around the target followed by a
        report of the continue-to result.

        Args:
          args: (list of str) Command-line arguments for the "cont" command.
          screen_info: Optional dict of screen information. If it contains a
            "cols" entry, that width is used as the numpy line width when
            printing the result.

        Returns:
          (RichTextLines) The windowed node listing plus the continue-to
          report, or an error if the target's node is not in the sorted
          transitive closure.
        """
        # Honor the screen width when printing the result.
        if screen_info and "cols" in screen_info:
            np_printoptions = {"linewidth": screen_info["cols"]}
        else:
            np_printoptions = {}

        parsed = self.arg_parsers["cont"].parse_args(args)

        # Determine which node is being continued to, so the _next pointer can
        # be set properly. Any ":..." suffix on the target is dropped to get
        # the node name.
        node_name = parsed.target_name.split(":")[0]
        if node_name not in self._sorted_nodes:
            return cli_shared.error(self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] %
                                    parsed.target_name)
        self._next = self._sorted_nodes.index(node_name)

        cont_result = self._node_stepper.cont(
            parsed.target_name,
            invalidate_from_updated_variables=(
                parsed.invalidate_from_updated_variables),
            restore_variable_values=parsed.restore_variable_values)
        self._completed_nodes.add(node_name)

        screen_output = debugger_cli_common.RichTextLines(
            ["Continued to %s:" % parsed.target_name, ""])
        screen_output.extend(self._report_last_input_types())
        screen_output.extend(self._report_last_updated())
        screen_output.extend(
            tensor_format.format_tensor(cont_result,
                                        parsed.target_name,
                                        include_metadata=True,
                                        np_printoptions=np_printoptions))

        # Generate a windowed view of the sorted transitive closure on which
        # the stepping is occurring: up to 2 nodes before the target and 2
        # after it, clamped to the list bounds. This uses the _next value set
        # above, before it is recalculated below.
        lower_bound = max(0, self._next - 2)
        upper_bound = min(len(self._sorted_nodes), self._next + 3)

        final_output = self.list_sorted_nodes(
            ["-l", str(lower_bound), "-u",
             str(upper_bound)])
        final_output.extend(debugger_cli_common.RichTextLines([""]))
        final_output.extend(screen_output)

        # Re-calculate the target of the next "step" action.
        self._calculate_next()

        return final_output