示例#1
0
    def before_run(self, run_context):
        """Build the run arguments for the upcoming session run call.

        On the first invocation, the LocalCLIDebugWrapperSession part of this
        object is initialized with the session taken from `run_context`.
        Every invocation bumps the run-call counter, hands the original
        fetches and feed_dict to `on_run_start()` and records the action
        found on its response.

        Args:
          run_context: Object exposing `session` and `original_args`; the
            latter provides `fetches` and `feed_dict`.

        Returns:
          A `SessionRunArgs` with no fetches, no feed_dict and a fresh
          `RunOptions`. When the action is DEBUG_RUN the options are first
          passed through `_decorate_options_for_debug()`.

        Raises:
          NotImplementedError: If the action is INVOKE_STEPPER.
        """
        # The wrapper-session initializer needs a live session, which only
        # becomes available here, so it is deferred to the first call.
        if not self._wrapper_initialized:
            local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
                self, run_context.session)
            self._wrapper_initialized = True

        self._run_call_count += 1

        # Re-package the hook's run context as the request type that
        # on_run_start() understands.
        original_args = run_context.original_args
        request = framework.OnRunStartRequest(
            original_args.fetches, original_args.feed_dict, None, None,
            self._run_call_count)
        response = self.on_run_start(request)

        action = response.action
        self._performed_action = action

        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=config_pb2.RunOptions())

        actions = framework.OnRunStartAction
        if action == actions.DEBUG_RUN:
            self._decorate_options_for_debug(
                run_args.options, run_context.session.graph)
        elif action == actions.INVOKE_STEPPER:
            raise NotImplementedError(
                "OnRunStartAction INVOKE_STEPPER has not been implemented.")

        return run_args
示例#2
0
    def before_run(self, run_context):
        """Build the run arguments for the upcoming session run call.

        On the first invocation, the LocalCLIDebugWrapperSession part of this
        object is initialized (with the stored UI type, dump root and thread
        name filter) and every pending tensor filter is registered on it.
        Every invocation bumps the run-call counter, hands the original
        fetches and feed_dict to `on_run_start()` and records the action
        found on its response.

        Args:
          run_context: Object exposing `session` and `original_args`; the
            latter provides `fetches` and `feed_dict`.

        Returns:
          A `SessionRunArgs` with no fetches, no feed_dict and a fresh
          `RunOptions`. When the action is DEBUG_RUN the options are first
          passed through `_decorate_options_for_debug()` together with
          watch options copied from the `on_run_start()` response.
        """
        wrapper_cls = local_cli_wrapper.LocalCLIDebugWrapperSession
        if not self._wrapper_initialized:
            wrapper_cls.__init__(
                self,
                run_context.session,
                ui_type=self._ui_type,
                dump_root=self._dump_root,
                thread_name_filter=self._thread_name_filter)

            # Filters requested before the wrapper session existed were only
            # queued; hand them to the wrapper now that it is constructed.
            for name, tensor_filter in self._pending_tensor_filters.items():
                wrapper_cls.add_tensor_filter(self, name, tensor_filter)

            self._wrapper_initialized = True

        self._run_call_count += 1

        # Re-package the hook's run context as the request type that
        # on_run_start() understands.
        original_args = run_context.original_args
        request = framework.OnRunStartRequest(
            original_args.fetches, original_args.feed_dict, None, None,
            self._run_call_count)
        response = self.on_run_start(request)

        action = response.action
        self._performed_action = action

        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=config_pb2.RunOptions())

        actions = framework.OnRunStartAction
        if action == actions.DEBUG_RUN:
            # Carry the watch restrictions chosen in on_run_start() over to
            # the run options.
            watch_options = framework.WatchOptions(
                node_name_regex_whitelist=response.node_name_regex_whitelist,
                op_type_regex_whitelist=response.op_type_regex_whitelist,
                tensor_dtype_regex_whitelist=(
                    response.tensor_dtype_regex_whitelist),
                tolerate_debug_op_creation_failures=(
                    response.tolerate_debug_op_creation_failures))
            self._decorate_options_for_debug(
                run_args.options, run_context.session.graph, watch_options)
        elif action == actions.INVOKE_STEPPER:
            # The graph's _finalized flag has to be cleared so that the
            # NodeStepper can insert ops for retrieving TensorHandles.
            # pylint: disable=protected-access
            run_context.session.graph._finalized = False
            # pylint: enable=protected-access

            with stepper.NodeStepper(
                    run_context.session,
                    original_args.fetches,
                    original_args.feed_dict) as node_stepper:
                self.invoke_node_stepper(
                    node_stepper, restore_variable_values_on_exit=True)

        return run_args
示例#3
0
    def before_run(self, run_context):
        """Build the run arguments for the upcoming session run call.

        When no session wrapper exists yet, a LocalCLIDebugWrapperSession is
        created around the session from `run_context` (with the stored UI
        type, dump root, thread name filter and config file path) and every
        pending tensor filter is registered on it. Every invocation bumps the
        wrapper's run-call counter, hands the original fetches and feed_dict
        to the wrapper's `on_run_start()` and records the action found on
        its response.

        Args:
          run_context: Object exposing `session` and `original_args`; the
            latter provides `fetches` and `feed_dict`.

        Returns:
          A `SessionRunArgs` with no fetches, no feed_dict and a fresh
          `RunOptions`. The options are decorated by the wrapper for
          debugging when the action is DEBUG_RUN, or for profiling when it
          is PROFILE_RUN.
        """
        if not self._session_wrapper:
            self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
                run_context.session,
                ui_type=self._ui_type,
                dump_root=self._dump_root,
                thread_name_filter=self._thread_name_filter,
                config_file_path=self._config_file_path)

            # Filters requested before the wrapper session existed were only
            # queued; hand them to the wrapper now that it is constructed.
            for name, tensor_filter in self._pending_tensor_filters.items():
                self._session_wrapper.add_tensor_filter(name, tensor_filter)

        wrapper = self._session_wrapper
        wrapper.increment_run_call_count()

        # Re-package the hook's run context as the request type that the
        # wrapper's on_run_start() understands.
        original_args = run_context.original_args
        request = framework.OnRunStartRequest(
            original_args.fetches, original_args.feed_dict, None, None,
            wrapper.run_call_count)
        response = wrapper.on_run_start(request)

        action = response.action
        self._performed_action = action

        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=config_pb2.RunOptions())

        actions = framework.OnRunStartAction
        # pylint: disable=protected-access
        if action == actions.DEBUG_RUN:
            wrapper._decorate_run_options_for_debug(
                run_args.options,
                response.debug_urls,
                debug_ops=response.debug_ops,
                node_name_regex_allowlist=response.node_name_regex_allowlist,
                op_type_regex_allowlist=response.op_type_regex_allowlist,
                tensor_dtype_regex_allowlist=(
                    response.tensor_dtype_regex_allowlist),
                tolerate_debug_op_creation_failures=(
                    response.tolerate_debug_op_creation_failures))
        elif action == actions.PROFILE_RUN:
            wrapper._decorate_run_options_for_profile(run_args.options)
        # pylint: enable=protected-access

        return run_args
示例#4
0
    def before_run(self, run_context):
        """Build the run arguments for the upcoming session run call.

        On the first invocation, the LocalCLIDebugWrapperSession part of this
        object is initialized with the session taken from `run_context` and
        the stored UI type. Every invocation bumps the run-call counter,
        hands the original fetches and feed_dict to `on_run_start()` and
        records the action found on its response.

        Args:
          run_context: Object exposing `session` and `original_args`; the
            latter provides `fetches` and `feed_dict`.

        Returns:
          A `SessionRunArgs` with no fetches, no feed_dict and a fresh
          `RunOptions`. When the action is DEBUG_RUN the options are first
          passed through `_decorate_options_for_debug()`.
        """
        # The wrapper-session initializer needs a live session, which only
        # becomes available here, so it is deferred to the first call.
        if not self._wrapper_initialized:
            local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
                self, run_context.session, ui_type=self._ui_type)
            self._wrapper_initialized = True

        self._run_call_count += 1

        # Re-package the hook's run context as the request type that
        # on_run_start() understands.
        original_args = run_context.original_args
        request = framework.OnRunStartRequest(
            original_args.fetches, original_args.feed_dict, None, None,
            self._run_call_count)
        response = self.on_run_start(request)

        action = response.action
        self._performed_action = action

        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=config_pb2.RunOptions())

        actions = framework.OnRunStartAction
        if action == actions.DEBUG_RUN:
            self._decorate_options_for_debug(
                run_args.options, run_context.session.graph)
        elif action == actions.INVOKE_STEPPER:
            # The graph's _finalized flag has to be cleared so that the
            # NodeStepper can insert ops for retrieving TensorHandles.
            # pylint: disable=protected-access
            run_context.session.graph._finalized = False
            # pylint: enable=protected-access

            node_stepper = stepper.NodeStepper(
                run_context.session,
                original_args.fetches,
                original_args.feed_dict)
            self.invoke_node_stepper(
                node_stepper, restore_variable_values_on_exit=True)

        return run_args