Пример #1
0
    def run_main_loop(self) -> None:
        """Start (or restart) the servo main loop and, optionally, the diagnostics loop.

        Any previously started main loop and diagnostics tasks are cancelled
        first. A done callback is attached to the new main loop task: when the
        task finishes with an exception (other than cancellation), the error is
        logged and re-raised from the callback. asyncio reports exceptions
        raised by callbacks through the event loop's exception handler.

        The diagnostics polling task is only started when the current servo's
        config does not set ``no_diagnostics``.
        """
        if self._main_loop_task:
            self._main_loop_task.cancel()

        if self._diagnostics_loop_task:
            self._diagnostics_loop_task.cancel()

        def _reraise_if_necessary(task: asyncio.Task) -> None:
            # Cancellation is how the main loop ends on restart, so only
            # genuine failures are logged and surfaced.
            try:
                if not task.cancelled():
                    task.result()
            except Exception as error:  # pylint: disable=broad-except
                self.logger.error(
                    f"Exiting from servo main loop due to error: {error} (task={task})"
                )
                self.logger.opt(exception=error).trace(
                    f"Exception raised by task {task}"
                )
                raise  # Ensure that we surface the error for handling

        self._main_loop_task = asyncio.create_task(
            self.main_loop(), name=f"main loop for servo {self.optimizer.id}"
        )
        self._main_loop_task.add_done_callback(_reraise_if_necessary)

        if not servo.current_servo().config.no_diagnostics:
            diagnostics_handler = servo.telemetry.DiagnosticsHandler(self.servo)
            self._diagnostics_loop_task = asyncio.create_task(
                diagnostics_handler.diagnostics_check(),
                name=f"diagnostics for servo {self.optimizer.id}",
            )
        else:
            self.logger.info(
                "Servo runner initialized with diagnostics polling disabled"
            )
Пример #2
0
        async def handle_progress_exception(
            progress: dict[str, Any], error: Exception
        ) -> None:
            """Recover from an error raised while reporting progress.

            Two errors trigger a restart: ``UnexpectedEventError`` (the servo
            lost synchronization with the optimizer) and ``EventCancelledError``
            (the optimizer cancelled the operation in progress). For either, every
            other asyncio task is cancelled and awaited. When the enclosing
            ``poll`` flag is truthy, a fresh main loop is then started. For a
            cancellation, a status built from the error is first posted back for
            ``progress["operation"]``. Any other error is only logged.

            Args:
                progress: The progress report being handled; its ``"operation"``
                    key is read when responding to a cancellation.
                error: The exception raised while reporting progress.
            """
            # FIXME: This needs to be made multi-servo aware
            if not isinstance(
                error,
                (servo.errors.UnexpectedEventError, servo.errors.EventCancelledError),
            ):
                self.logger.error(
                    f"unrecognized exception passed to progress exception handler: {error}"
                )
                return

            # Restart the main event loop if we get out of sync with the server
            if isinstance(error, servo.errors.UnexpectedEventError):
                self.logger.error(
                    "servo has lost synchronization with the optimizer: restarting"
                )
            else:  # EventCancelledError
                self.logger.error(
                    "optimizer has cancelled operation in progress: cancelling and restarting loop"
                )

                # Post a status to resolve the operation
                operation = progress["operation"]
                status = servo.api.Status.from_error(error)
                self.logger.error(f"Responding with {status.dict()}")
                runner = self._runner_for_servo(servo.current_servo())
                await runner._post_event(operation, status.dict())

            # Cancel everything except ourselves, then wait for the cancellations
            # to settle (return_exceptions so one failure doesn't abort the wait)
            current = asyncio.current_task()
            tasks = [t for t in asyncio.all_tasks() if t is not current]
            self.logger.info(f"Cancelling {len(tasks)} outstanding tasks")
            for task in tasks:
                task.cancel()

            await asyncio.gather(*tasks, return_exceptions=True)

            # Restart a fresh main loop
            if poll:
                runner = self._runner_for_servo(servo.current_servo())
                runner.run_main_loop()
Пример #3
0
    def __init__(
        self,
        message: str = "",
        reason: Optional[str] = None,
        *args,
        assembly: Optional[servo.Assembly] = None,
        servo_: Optional[servo.Servo] = None,
        connector: Optional[servo.Connector] = None,
        event: Optional[servo.Event] = None,
    ) -> None:
        """Initialize the error, filling in missing context from the context vars.

        Args:
            message: Error message passed to the base exception.
            reason: Optional reason string, stored as ``_reason``.
            *args: Extra positional arguments forwarded to the base exception.
            assembly: Related assembly; defaults to ``servo.current_assembly()``.
            servo_: Related servo; defaults to ``servo.current_servo()``.
            connector: Related connector; defaults to the resolved servo's
                ``connector`` attribute, or ``None`` if it has none.
            event: Related event; defaults to the resolved servo's ``event``
                attribute, or ``None`` if it has none.
        """
        super().__init__(message, *args)

        # Local import used to reach the context-var helpers.
        # NOTE(review): presumably avoids a circular import at module load — confirm
        import servo

        self._reason = reason
        self._assembly = assembly or servo.current_assembly()

        resolved_servo = servo_ or servo.current_servo()
        self._servo = resolved_servo
        self._connector = connector or getattr(resolved_servo, "connector", None)
        self._event = event or getattr(resolved_servo, "event", None)

        self._created_at = datetime.datetime.now()
Пример #4
0
 async def _report_progress(**kwargs) -> None:
     """Forward a progress report to the servo active in this context, if any.

     Keyword arguments are passed unchanged to that servo's ``report_progress``;
     when no servo is current the call is a no-op.
     """
     active_servo = servo.current_servo()
     if not active_servo:
         return
     await active_servo.report_progress(**kwargs)
Пример #5
0
class Mixin(abc.ABC):
    """Provides functionality for interacting with the Opsani API via httpx.

    The mixin requires the implementation of the `api_client_options` method
    which is responsible for providing details around base URL, HTTP headers,
    timeouts, proxies, SSL configuration, etc. for initializing
    `httpx.AsyncClient` and `httpx.Client` instances.
    """
    @property
    @abc.abstractmethod
    def api_client_options(self) -> Dict[str, Any]:
        """Return a dict of options for initializing httpx API client objects.

        An implementation must be provided in subclasses derived from the mixin
        and is responsible for appropriately configuring the base URL, HTTP
        headers, timeouts, proxies, SSL configuration, transport flags, etc.

        The dict returned is passed directly to the initializer of
        `httpx.AsyncClient` and `httpx.Client` objects constructed by the
        `api_client` and `api_client_sync` methods.
        """
        ...

    def api_client(self, **kwargs) -> httpx.AsyncClient:
        """Return an asynchronous client for interacting with the Opsani API.

        Keyword arguments override entries from `api_client_options`.
        """
        return httpx.AsyncClient(**{**self.api_client_options, **kwargs})

    def api_client_sync(self, **kwargs) -> httpx.Client:
        """Return a synchronous client for interacting with the Opsani API.

        Keyword arguments override entries from `api_client_options`.
        """
        return httpx.Client(**{**self.api_client_options, **kwargs})

    async def report_progress(self, **kwargs) -> None:
        """Post a progress report to the Opsani API.

        Keyword arguments are forwarded to `progress_request` to build the event.

        Raises:
            servo.errors.UnexpectedEventError: The optimizer replied with an
                ``unexpected_event`` status (servo and backend are out of sync).
            servo.errors.EventCancelledError: The optimizer replied with a
                ``cancelled`` status.
            ValueError: The optimizer replied with an unrecognized status.
        """
        request = self.progress_request(**kwargs)
        status = await self._post_event(*request)

        if status.status == OptimizerStatuses.ok:
            pass
        elif status.status == OptimizerStatuses.unexpected_event:
            # We have lost sync with the backend, raise an exception to halt broken execution
            raise servo.errors.UnexpectedEventError(status.reason)
        elif status.status == OptimizerStatuses.cancelled:
            # Optimizer wants to cancel the operation
            raise servo.errors.EventCancelledError(status.reason
                                                   or "Command cancelled")
        elif status.status == OptimizerStatuses.invalid:
            servo.logger.warning("progress report was rejected as invalid")
        else:
            raise ValueError(f'unknown error status: "{status.status}"')

    def progress_request(
        self,
        operation: str,
        progress: servo.types.Numeric,
        started_at: datetime,
        message: Optional[str],
        *,
        connector: Optional[str] = None,
        event_context: Optional["servo.events.EventContext"] = None,
        time_remaining: Optional[Union[servo.types.Numeric,
                                       servo.types.Duration]] = None,
        logs: Optional[List[str]] = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build the ``(operation, params)`` pair for a progress event.

        ``params`` always contains ``progress`` (values below 1.0 are treated
        as fractions and scaled to a percentage) and ``runtime`` (seconds since
        ``started_at``), plus ``message`` when it is not ``None``.

        NOTE(review): ``connector``, ``event_context`` and ``logs`` are accepted
        but unused, and ``time_remaining`` is only type-checked — none of them
        reach the returned params. Confirm whether that is intentional.

        Raises:
            ValueError: ``time_remaining`` is truthy but neither an int/float
                nor a ``timedelta``.
        """
        def set_if(d: Dict, k: str, v: Any):
            if v is not None:
                d[k] = v

        # Normalize progress to positive percentage
        if progress < 1.0:
            progress = progress * 100

        # Calculate runtime
        runtime = servo.types.Duration(datetime.now() - started_at)

        # Produce human readable and remaining time in seconds values (if given)
        if time_remaining:
            if isinstance(time_remaining, (int, float)):
                time_remaining_in_seconds = time_remaining
                time_remaining = servo.types.Duration(
                    time_remaining_in_seconds)
            elif isinstance(time_remaining, timedelta):
                time_remaining_in_seconds = time_remaining.total_seconds()
            else:
                raise ValueError(
                    f"Unknown value of type '{time_remaining.__class__.__name__}' for parameter 'time_remaining'"
                )
        else:
            time_remaining_in_seconds = None

        params = dict(
            progress=float(progress),
            runtime=float(runtime.total_seconds()),
        )
        set_if(params, "message", message)

        return (operation, params)

    # Plain function (no ``self``): it is handed directly to the
    # ``@backoff.on_exception`` decorator below while the class body executes.
    def _is_fatal_status_code(error: Exception) -> bool:
        """Return True for HTTP status errors below 500, which are not retried."""
        if isinstance(error, httpx.HTTPStatusError):
            if error.response.status_code < 500:
                servo.logger.error(
                    f"Giving up on non-retryable HTTP status code {error.response.status_code} ({error.response.reason_phrase}) "
                )
                return True
        return False

    @backoff.on_exception(
        backoff.expo,
        httpx.HTTPError,
        max_time=lambda: servo.current_servo() and servo.current_servo(
        ).config.settings.backoff.max_time(),
        max_tries=lambda: servo.current_servo() and servo.current_servo().
        config.settings.backoff.max_tries(),
        giveup=_is_fatal_status_code,
    )
    async def _post_event(self, event: Events,
                          param) -> Union[CommandResponse, Status]:
        """POST an event to the ``servo`` endpoint and parse the reply.

        ``httpx.HTTPError`` failures are retried with exponential backoff, with
        limits read from the current servo's backoff settings. Status codes
        below 500 give up immediately (see ``_is_fatal_status_code``).

        Returns:
            The reply parsed as ``CommandResponse`` or ``Status``.
        """
        async with self.api_client() as client:
            event_request = Request(event=event, param=param)
            self.logger.trace(
                f"POST event request: {devtools.pformat(event_request.json())}"
            )

            try:
                response = await client.post("servo",
                                             data=event_request.json())
                response.raise_for_status()
                response_json = response.json()
                self.logger.trace(
                    f"POST event response ({response.status_code} {response.reason_phrase}): {devtools.pformat(response_json)}"
                )
                self.logger.trace(_redacted_to_curl(response.request))

                return pydantic.parse_obj_as(Union[CommandResponse, Status],
                                             response_json)

            except httpx.HTTPError as error:
                # Only status errors carry a response. Transport failures
                # (connect errors, timeouts, ...) do not, and reading
                # `error.response` on them would raise AttributeError here —
                # masking the real error and preventing the backoff retry.
                error_response = getattr(error, "response", None)
                response_text = (devtools.pformat(error_response.text)
                                 if error_response is not None else
                                 "<no response>")
                self.logger.error(
                    f'HTTP error "{error.__class__.__name__}" encountered while posting "{event}" event: {error}, for '
                    f"url {error.request.url} \n\n Response: {response_text}"
                )
                self.logger.trace(_redacted_to_curl(error.request))
                raise
Пример #6
0
class DiagnosticsHandler(servo.logging.Mixin, servo.api.Mixin):
    """Polls the optimizer for diagnostics requests and answers them.

    The state returned by the diagnostics check endpoint decides what happens:
    nothing (``withhold``), an upload of the servo config plus recent logs
    (``send``), or disabling diagnostics and ending the polling (``stop``).
    """

    servo: servo.Servo = None
    _running: bool = False

    def __init__(self, servo: servo.Servo) -> None:  # noqa: D10
        self.servo = servo

    @property
    def api_client_options(self) -> dict[str, Any]:
        # Adopt the servo config for driving the API mixin
        return self.servo.api_client_options

    async def diagnostics_check(self) -> None:
        """Poll for diagnostics requests every 60 seconds until stopped.

        On ``send``, diagnostics are uploaded and the check state is reset to
        ``withhold``. On ``stop``, ``no_diagnostics`` is set on the servo config
        and polling ends. Any exception is logged and also ends polling.
        """
        self._running = True

        while self._running:
            try:
                self.logger.trace("Polling for diagnostics request")
                request = await self._diagnostics_api(
                    method="GET",
                    endpoint=DIAGNOSTICS_CHECK_ENDPOINT,
                    output_model=DiagnosticStates,
                )

                if request == DiagnosticStates.withhold:
                    self.logger.trace("Withholding diagnostics")

                elif request == DiagnosticStates.send:
                    self.logger.info("Diagnostics requested, gathering and sending")
                    diagnostic_data = await self._get_diagnostics()

                    await self._diagnostics_api(
                        method="PUT",
                        endpoint=DIAGNOSTICS_OUTPUT_ENDPOINT,
                        output_model=servo.api.Status,
                        json=diagnostic_data.dict(),
                    )

                    # Reset diagnostics check state to withhold
                    reset_state = DiagnosticStates.withhold
                    await self._diagnostics_api(
                        method="PUT",
                        endpoint=DIAGNOSTICS_CHECK_ENDPOINT,
                        output_model=servo.api.Status,
                        json=reset_state,
                    )

                elif request == DiagnosticStates.stop:
                    self.logger.info(
                        "Received request to disable polling for diagnostics"
                    )
                    self.servo.config.no_diagnostics = True
                    self._running = False
                else:
                    # A bare `raise` here has no active exception and only yields
                    # a confusing "No active exception to reraise" RuntimeError.
                    raise ValueError(
                        f"Unexpected diagnostics request state: {request!r}"
                    )

                await asyncio.sleep(60)

            except Exception:
                self.logger.exception(
                    "Diagnostics check failed with unrecoverable error"
                )  # exception logger logs the exception object
                self._running = False

    async def _get_diagnostics(self) -> Diagnostics:
        """Collect the servo config and the tail of the log file for upload."""
        async with aiofiles.open(servo.logging.logs_path, "r") as log_file:
            logs = await log_file.read()

        # Strip emoji from logs :(
        raw_logs = logs.encode("ascii", "ignore").decode()

        # Limit + truncate per 1MiB /assets limit, leaving 10000 characters of
        # room for the configmap. (The former `[-ONE_MiB - 10000:]` slice kept
        # ONE_MiB + 10000 characters, i.e. *more* than the limit.)
        log_budget = ONE_MiB - 10000
        log_lines = raw_logs[-log_budget:].split("\n")
        if len(raw_logs) > log_budget:
            # The first line of a truncated tail is most likely partial
            log_lines = log_lines[1:]

        # NOTE(review): entries are keyed by timestamp, so lines sharing a
        # timestamp overwrite each other — confirm that loss is acceptable.
        log_dict = {}
        last_key = None
        for line in filter(None, log_lines):
            timestamp, separator, message = line.partition("|")
            if separator:
                last_key = timestamp.strip()
                log_dict[last_key] = message.strip()
            elif last_key is not None:
                # Handle rare multi-line logs e.g. from self.tuning_container.resources
                log_dict[last_key] += line
            # else: continuation line with no preceding entry in the window; skip it

        # TODO: Re-evaluate roundtripping through JSON required to produce primitives
        config_dict = self.servo.config.json(exclude_unset=True, exclude_none=True)
        config_data = json.loads(config_dict)

        return Diagnostics(configmap=config_data, logs=log_dict)

    @backoff.on_exception(
        backoff.expo,
        httpx.HTTPError,
        max_time=lambda: servo.current_servo()
        and servo.current_servo().config.settings.backoff.max_time(),
        max_tries=lambda: DIAGNOSTICS_MAX_RETRIES,
        logger="diagnostics-backoff",
        on_giveup=lambda x: asyncio.current_task().cancel(),
    )
    async def _diagnostics_api(
        self,
        method: str,
        endpoint: str,
        output_model: pydantic.BaseModel,
        json: Optional[dict] = None,
    ) -> Union[DiagnosticStates, servo.api.Status]:
        """Send a diagnostics request and parse the reply into ``output_model``.

        The request body is ``{"data": json}``. A top-level ``data`` key in the
        reply is unwrapped before parsing. A reply that fails validation, or an
        HTTP status error below 500, yields ``DiagnosticStates.withhold``.

        Other ``httpx.HTTPError``s are re-raised so the backoff decorator
        retries them. When retries are exhausted, the current task is cancelled.
        """
        async with self.api_client() as client:
            self.logger.trace(f"{method} diagnostic request")
            try:
                response = await client.request(
                    method=method, url=endpoint, json=dict(data=json)
                )
                response.raise_for_status()
                response_json = response.json()

                # Handle /diagnostics-check retrieval
                if "data" in response_json:
                    response_json = response_json["data"]

                self.logger.trace(
                    f"{method} diagnostic request response ({response.status_code} {response.reason_phrase}): {devtools.pformat(response_json)}"
                )
                self.logger.trace(servo.api._redacted_to_curl(response.request))
                try:
                    return pydantic.parse_obj_as(output_model, response_json)
                except pydantic.ValidationError as error:
                    # Should not raise due to improperly set diagnostic states.
                    # Log at DEBUG with the traceback attached: `logger.exception`
                    # always logs at ERROR, whatever `level_id` kwarg it is given.
                    self.logger.opt(exception=error).debug(
                        f"Malformed diagnostic {method} response"
                    )
                    return DiagnosticStates.withhold

            except httpx.HTTPStatusError as error:
                if error.response.status_code < 500:
                    self.logger.debug(
                        f"Giving up on non-retryable HTTP status code {error.response.status_code} ({error.response.reason_phrase}) for url: {error.request.url}"
                    )
                    return DiagnosticStates.withhold
                self.logger.trace(servo.api._redacted_to_curl(error.request))
                raise
            except httpx.HTTPError:
                # Transport failures (connect errors, timeouts, ...) carry no
                # response to inspect; re-raise for the backoff decorator to retry.
                raise