Example #1
class AwsTestConfig(CommonTestConfigMixin, AwsConfig):
    _initial_state: Dict[str, Any] = pydantic.PrivateAttr(default_factory=dict)
    _nimbo_config_file_exists: bool = pydantic.PrivateAttr(default=True)

    def inject_required_config(self, *cases: RequiredCase) -> None:
        cases = RequiredCase.decompose(*cases)

        if RequiredCase.NONE in cases:
            self.telemetry = False
        if RequiredCase.MINIMAL in cases:
            self.region_name = "eu-west-1"
        if RequiredCase.STORAGE in cases:
            self.local_datasets_path = "random_datasets"
            self.local_results_path = "random_results"
            self.s3_datasets_path = S3_DATASETS_PATH
            self.s3_results_path = S3_RESULTS_PATH
        if RequiredCase.INSTANCE in cases:
            self.image = "ubuntu18-latest-drivers"
            self.disk_size = 64
            self.instance_type = "p2.xlarge"
            self.instance_key = self._find_instance_key(
                ASSETS_PATH, self.region_name)
        if RequiredCase.JOB in cases:
            self.conda_env = CONDA_ENV

    @staticmethod
    def _find_instance_key(directory: str, region_name: str) -> str:
        # Instance keys are supposed to be prefixed with region name for testing
        for file in os.listdir(directory):
            if file.startswith(region_name) and file.endswith(".pem"):
                return file
Example #2
class NimboTestConfig(NimboConfig):
    _initial_state: Dict[str, Any] = pydantic.PrivateAttr(default_factory=dict)
    _nimbo_config_file_exists: bool = pydantic.PrivateAttr(default=True)

    def save_initial_state(self):
        self._initial_state = self.dict()

    def reset_required_config(self) -> None:
        if len(self._initial_state) <= 1:
            raise ValueError(
                "You must run save_initial_state to use reset_required_config")

        # user_id is set after save_initial_state
        if not self._initial_state["user_id"]:
            self._initial_state["user_id"] = self.user_id

        for key, value in self._initial_state.items():
            setattr(self, key, value)

    def inject_required_config(self, *cases: RequiredCase) -> None:
        cases = RequiredCase.decompose(*cases)

        if RequiredCase.NONE in cases:
            self.telemetry = False
        if RequiredCase.MINIMAL in cases:
            self.region_name = "eu-west-1"
        if RequiredCase.STORAGE in cases:
            self.local_datasets_path = "random_datasets"
            self.local_results_path = "random_results"
            self.s3_datasets_path = S3_DATASETS_PATH
            self.s3_results_path = S3_RESULTS_PATH
        if RequiredCase.INSTANCE in cases:
            self.image = "ubuntu18-latest-drivers"
            self.disk_size = 64
            self.instance_type = "p2.xlarge"
            self.instance_key = self._find_instance_key(
                ASSETS_PATH, self.region_name)
        if RequiredCase.JOB in cases:
            self.conda_env = CONDA_ENV

    @staticmethod
    def _find_instance_key(directory: str, region_name: str) -> str:
        # Instance keys are supposed to be prefixed with region name for testing
        for file in os.listdir(directory):
            if file.startswith(region_name) and file.endswith(".pem"):
                return file
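
The two test configs above use pydantic.PrivateAttr to stash a snapshot of the public fields so tests can mutate and later restore the configuration. A minimal, self-contained sketch of that snapshot/restore pattern (hypothetical field names, not the actual Nimbo classes):

# Illustrative sketch of the snapshot/restore pattern; field names are made up.
from typing import Any, Dict, Optional

import pydantic


class ExampleConfig(pydantic.BaseModel):
    region_name: Optional[str] = None
    disk_size: Optional[int] = None

    # Private attrs are not model fields: excluded from validation and .dict().
    _initial_state: Dict[str, Any] = pydantic.PrivateAttr(default_factory=dict)

    def save_initial_state(self) -> None:
        self._initial_state = self.dict()  # snapshot public fields only

    def reset(self) -> None:
        for key, value in self._initial_state.items():
            setattr(self, key, value)


config = ExampleConfig()
config.save_initial_state()
config.region_name = "eu-west-1"
config.reset()
assert config.region_name is None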
Example #3
class AssemblyRunner(pydantic.BaseModel, servo.logging.Mixin):
    assembly: servo.Assembly
    runners: list[ServoRunner] = []
    progress_handler: Optional[servo.logging.ProgressHandler] = None
    progress_handler_id: Optional[int] = None
    _running: bool = pydantic.PrivateAttr(False)

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, assembly: servo.Assembly, **kwargs) -> None:
        super().__init__(assembly=assembly, **kwargs)

    def _runner_for_servo(self, servo: servo.Servo) -> ServoRunner:
        for runner in self.runners:
            if runner.servo == servo:
                return runner

        raise KeyError(f'no runner was found for the servo: "{servo}"')

    @property
    def running(self) -> bool:
        return self._running

    def run(
        self, *, poll: bool = True, interactive: bool = False, debug: bool = False
    ) -> None:
        """Asynchronously run all servos active within the assembly.

        Running the assembly takes over the current event loop and schedules a `ServoRunner` instance for each servo active in the assembly.
        """
        if self.running:
            raise RuntimeError("Cannot run an assembly that is already running")

        self._running = True
        loop = asyncio.get_event_loop()

        # Setup signal handling
        signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGUSR1)
        for s in signals:
            loop.add_signal_handler(
                s, lambda s=s: asyncio.create_task(self._shutdown(loop, signal=s))
            )

        if not debug:
            loop.set_exception_handler(self._handle_exception)
        else:
            loop.set_exception_handler(None)

        # Setup logging
        async def _report_progress(**kwargs) -> None:
            # Forward to the active servo...
            if servo_ := servo.current_servo():
                await servo_.report_progress(**kwargs)
            else:
Example #4
class Document(models.BaseModel):
    _doc_id: int = pydantic.PrivateAttr()

    def __init__(self, value: dict, doc_id: int):
        super().__init__(**value)

        self._doc_id = doc_id

    @property
    def doc_id(self):
        return self._doc_id
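
Here PrivateAttr keeps the storage-level document id out of the declared fields, so it is neither validated nor serialized. A self-contained version of the same pattern (the Record model and its title field are hypothetical, since the models.BaseModel base is not shown):

# Minimal sketch of the pattern above; Record and its title field are hypothetical.
import pydantic


class Record(pydantic.BaseModel):
    title: str
    _doc_id: int = pydantic.PrivateAttr()

    def __init__(self, value: dict, doc_id: int):
        super().__init__(**value)
        self._doc_id = doc_id

    @property
    def doc_id(self) -> int:
        return self._doc_id


record = Record({"title": "hello"}, doc_id=42)
assert record.doc_id == 42
assert record.dict() == {"title": "hello"}  # the private id never serializes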
Example #5
class BaseConfiguration(AbstractBaseConfiguration):
    """
    BaseConfiguration is the base configuration class for Opsani Servo Connectors.

    BaseConfiguration subclasses are typically paired 1:1 with a Connector class
    that inherits from `servo.connector.Connector` and implements the business logic
    of the connector. Configuration classes are connector specific and designed
    to be initialized from commandline arguments, environment variables, and defaults.
    Connectors are initialized with a valid settings instance capable of providing necessary
    configuration for the connector to function.

    The optional `description` field provides a textual description of the
    configuration stanza, useful for differentiating between configurations
    within assemblies.
    """

    description: Optional[str] = pydantic.Field(
        None, description="An optional description of the configuration.")
    __optimizer__: Optional[Optimizer] = pydantic.PrivateAttr(None)
    __settings__: Optional[CommonConfiguration] = pydantic.PrivateAttr(
        default_factory=lambda: CommonConfiguration(), )

    def __init__(
        self,
        __optimizer__: Optional[Optimizer] = None,
        __settings__: Optional[CommonConfiguration] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.__optimizer__ = __optimizer__
        if __settings__:
            self.__settings__ = __settings__

    @property
    def optimizer(self) -> Optional[Optimizer]:
        """Returns the Optimizer this configuration is bound to."""
        return self.__optimizer__

    @property
    def settings(self) -> Optional[CommonConfiguration]:
        """Returns the CommonConfiguration settings this configuration is bound to."""
        return self.__settings__
Example #6
class Telemetry(pydantic.BaseModel):
    """Class and convenience methods for storage of arbitrary servo metadata"""

    _values: dict[str, str] = pydantic.PrivateAttr(default_factory=dict)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self["servox.version"] = str(servo.__version__)
        self["servox.platform"] = platform.platform()

        if servo_ns := os.environ.get("POD_NAMESPACE"):
            self["servox.namespace"] = servo_ns
Example #7
class BigQueryConfig(BaseTimeWindowConfig, SQLAlchemyConfig):
    scheme: str = "bigquery"
    project_id: Optional[str] = None
    lineage_client_project_id: Optional[str] = None

    log_page_size: Optional[pydantic.PositiveInt] = 1000
    credential: Optional[BigQueryCredential]
    # extra_client_options, include_table_lineage and max_query_duration are relevant only when computing the lineage.
    extra_client_options: Dict[str, Any] = {}
    include_table_lineage: Optional[bool] = True
    max_query_duration: timedelta = timedelta(minutes=15)

    credentials_path: Optional[str] = None
    bigquery_audit_metadata_datasets: Optional[List[str]] = None
    use_exported_bigquery_audit_metadata: bool = False
    use_date_sharded_audit_log_tables: bool = False
    _credentials_path: Optional[str] = pydantic.PrivateAttr(None)
    use_v2_audit_metadata: Optional[bool] = False

    def __init__(self, **data: Any):
        super().__init__(**data)

        if self.credential:
            self._credentials_path = self.credential.create_credential_temp_file(
            )
            logger.debug(
                f"Creating temporary credential file at {self._credentials_path}"
            )
            os.environ[
                "GOOGLE_APPLICATION_CREDENTIALS"] = self._credentials_path

    def get_sql_alchemy_url(self):
        if self.project_id:
            return f"{self.scheme}://{self.project_id}"
        # When project_id is not set, we will attempt to detect the project ID
        # based on the credentials or environment variables.
        # See https://github.com/mxmzdlv/pybigquery#authentication.
        return f"{self.scheme}://"

    @pydantic.validator("platform_instance")
    def bigquery_doesnt_need_platform_instance(cls, v):
        if v is not None:
            raise ConfigurationError(
                "BigQuery project ids are globally unique. You do not need to specify a platform instance."
            )

    @pydantic.validator("platform")
    def platform_is_always_bigquery(cls, v):
        return "bigquery"
Example #8
class Mixin(pydantic.BaseModel):
    """Provides convenience interfaces for working with asyncrhonously repeating tasks."""

    __private_attributes__ = {
        "_repeating_tasks": pydantic.PrivateAttr({}),
    }

    def __init_subclass__(cls, **kwargs) -> None:  # noqa: D105
        super().__init_subclass__(**kwargs)

        repeaters = {}
        for name, method in cls.__dict__.items():
            if repeat_params := getattr(method, "__repeating__", None):
                repeaters[method] = repeat_params

        cls.__repeaters__ = repeaters
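
Unlike the other examples, this mixin registers its private attribute by populating __private_attributes__ directly rather than annotating a class attribute. Under pydantic v1 the more common spelling is an annotated declaration; a sketch of the equivalent form (the task value type is an assumption):

# Illustrative equivalent using an annotated private attribute.
import asyncio
from typing import Dict

import pydantic


class AnnotatedMixin(pydantic.BaseModel):
    # Same registration as __private_attributes__ = {"_repeating_tasks": ...}
    _repeating_tasks: Dict[str, asyncio.Task] = pydantic.PrivateAttr(default_factory=dict)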
Example #9
class BigQueryUsageConfig(DatasetSourceConfigBase, BaseUsageConfig):
    projects: Optional[List[str]] = None
    project_id: Optional[str] = None  # deprecated in favor of `projects`
    extra_client_options: dict = {}
    table_pattern: Optional[AllowDenyPattern] = None

    log_page_size: Optional[pydantic.PositiveInt] = 1000
    query_log_delay: Optional[pydantic.PositiveInt] = None
    max_query_duration: timedelta = timedelta(minutes=15)
    use_v2_audit_metadata: Optional[bool] = False
    credential: Optional[BigQueryCredential]
    _credentials_path: Optional[str] = pydantic.PrivateAttr(None)

    def __init__(self, **data: Any):
        super().__init__(**data)
        if self.credential:
            self._credentials_path = self.credential.create_credential_temp_file(
            )
            logger.debug(
                f"Creating temporary credential file at {self._credentials_path}"
            )
            os.environ[
                "GOOGLE_APPLICATION_CREDENTIALS"] = self._credentials_path

    @pydantic.validator("project_id")
    def note_project_id_deprecation(cls, v, values, **kwargs):
        logger.warning(
            "bigquery-usage project_id option is deprecated; use projects instead"
        )
        values["projects"] = [v]
        return None

    @pydantic.validator("platform")
    def platform_is_always_bigquery(cls, v):
        return "bigquery"

    @pydantic.validator("platform_instance")
    def bigquery_platform_instance_is_meaningless(cls, v):
        raise ConfigurationError(
            "BigQuery project-ids are globally unique. You don't need to provide a platform_instance"
        )

    def get_allow_pattern_string(self) -> str:
        return "|".join(self.table_pattern.allow) if self.table_pattern else ""

    def get_deny_pattern_string(self) -> str:
        return "|".join(self.table_pattern.deny) if self.table_pattern else ""
Example #10
class Artifact(pydantic.BaseModel):
    """A pydantic model for representing an artifact in the cache."""

    key: str
    serializer: str
    load_kwargs: typing.Optional[typing.Dict] = pydantic.Field(
        default_factory=dict)
    dump_kwargs: typing.Optional[typing.Dict] = pydantic.Field(
        default_factory=dict)
    additional_metadata: typing.Optional[typing.Dict] = pydantic.Field(
        default_factory=dict)
    created_at: typing.Optional[datetime.datetime] = pydantic.Field(
        default_factory=datetime.datetime.utcnow)
    _value: typing.Any = pydantic.PrivateAttr(default=None)

    class Config:
        validate_assignment = True
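
Because validate_assignment applies only to declared fields, assigning to the _value private attribute bypasses validation entirely, while field assignments are still checked. A hedged usage sketch against the class above:

# Usage sketch of the Artifact model shown above.
artifact = Artifact(key="model.pkl", serializer="pickle")
artifact._value = object()  # allowed: private attributes are never validated

try:
    artifact.created_at = "not-a-datetime"  # rejected by validate_assignment
except pydantic.ValidationError:
    pass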
Example #11
class InputSpec(BaseModel):
    """Component input definitions.

    Attributes:
        type: The type of the input.
        default: Optional; the default value for the input.
        description: Optional; the user description of the input.
        optional: Whether the input is optional. An input is optional when it has
            an explicit default value.
    """
    type: Union[str, dict]
    default: Optional[Any] = None
    description: Optional[str] = None
    _optional: bool = pydantic.PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        # An input is optional if a default value is explicitly specified.
        self._optional = 'default' in data

    @property
    def optional(self) -> bool:
        return self._optional
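
Since _optional is derived from whether the key 'default' appeared in the raw constructor data rather than from the field's value, an explicit default=None still marks the input as optional. Usage sketch (assuming BaseModel here is pydantic.BaseModel):

# Usage sketch of InputSpec as defined above.
assert InputSpec(type="String").optional is False
assert InputSpec(type="String", default=None).optional is True
assert InputSpec(type="Integer", default=3).optional is True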
Example #12
class BigQueryUsageConfig(DatasetSourceConfigBase, BaseUsageConfig):
    projects: Optional[List[str]] = pydantic.Field(
        default=None,
        description="List of project ids to ingest usage from. If not specified, will infer from environment.",
    )
    project_id: Optional[str] = pydantic.Field(
        default=None,
        description="Project ID to ingest usage from. If not specified, will infer from environment. Deprecated in favour of projects ",
    )
    extra_client_options: dict = pydantic.Field(
        default_factory=dict,
        description="Additional options to pass to google.cloud.logging_v2.client.Client.",
    )
    use_v2_audit_metadata: Optional[bool] = pydantic.Field(
        default=False,
        description="Whether to ingest logs using the v2 format. Required if use_exported_bigquery_audit_metadata is set to True.",
    )

    bigquery_audit_metadata_datasets: Optional[List[str]] = pydantic.Field(
        description="A list of datasets that contain a table named cloudaudit_googleapis_com_data_access which contain BigQuery audit logs, specifically, those containing BigQueryAuditMetadata. It is recommended that the project of the dataset is also specified, for example, projectA.datasetB.",
    )
    use_exported_bigquery_audit_metadata: bool = pydantic.Field(
        default=False,
        description="When configured, use BigQueryAuditMetadata in bigquery_audit_metadata_datasets to compute usage information.",
    )

    use_date_sharded_audit_log_tables: bool = pydantic.Field(
        default=False,
        description="Whether to read date sharded tables or time partitioned tables when extracting usage from exported audit logs.",
    )

    table_pattern: AllowDenyPattern = pydantic.Field(
        default=AllowDenyPattern.allow_all(),
        description="List of regex patterns for tables to include/exclude from ingestion.",
    )
    dataset_pattern: AllowDenyPattern = pydantic.Field(
        default=AllowDenyPattern.allow_all(),
        description="List of regex patterns for datasets to include/exclude from ingestion.",
    )
    log_page_size: pydantic.PositiveInt = pydantic.Field(
        default=1000,
        description="",
    )

    query_log_delay: Optional[pydantic.PositiveInt] = pydantic.Field(
        default=None,
        description="To account for the possibility that the query event arrives after the read event in the audit logs, we wait for at least query_log_delay additional events to be processed before attempting to resolve BigQuery job information from the logs. If query_log_delay is None, it gets treated as an unlimited delay, which prioritizes correctness at the expense of memory usage.",
    )

    max_query_duration: timedelta = pydantic.Field(
        default=timedelta(minutes=15),
        description="Correction to pad start_time and end_time with. For handling the case where the read happens within our time range but the query completion event is delayed and happens after the configured end time.",
    )

    credential: Optional[BigQueryCredential] = pydantic.Field(
        default=None,
        description="Bigquery credential. Required if GOOGLE_APPLICATION_CREDENTIALS enviroment variable is not set. See this example recipe for details",
    )
    _credentials_path: Optional[str] = pydantic.PrivateAttr(None)
    temp_table_dataset_prefix: str = pydantic.Field(
        default="_",
        description="If you are creating temp tables in a dataset with a particular prefix you can use this config to set the prefix for the dataset. This is to support workflows from before bigquery's introduction of temp tables. By default we use `_` because of datasets that begin with an underscore are hidden by default https://cloud.google.com/bigquery/docs/datasets#dataset-naming.",
    )

    def __init__(self, **data: Any):
        super().__init__(**data)
        if self.credential:
            self._credentials_path = self.credential.create_credential_temp_file()
            logger.debug(
                f"Creating temporary credential file at {self._credentials_path}"
            )
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._credentials_path

    @pydantic.validator("project_id")
    def note_project_id_deprecation(cls, v, values, **kwargs):
        logger.warning(
            "bigquery-usage project_id option is deprecated; use projects instead"
        )
        values["projects"] = [v]
        return None

    @pydantic.validator("platform")
    def platform_is_always_bigquery(cls, v):
        return "bigquery"

    @pydantic.validator("platform_instance")
    def bigquery_platform_instance_is_meaningless(cls, v):
        raise ConfigurationError(
            "BigQuery project-ids are globally unique. You don't need to provide a platform_instance"
        )

    @pydantic.validator("use_exported_bigquery_audit_metadata")
    def use_exported_bigquery_audit_metadata_uses_v2(cls, v, values):
        if v is True and not values["use_v2_audit_metadata"]:
            raise ConfigurationError(
                "To use exported BigQuery audit metadata, you must also use v2 audit metadata"
            )
        return v

    def get_table_allow_pattern_string(self) -> str:
        return "|".join(self.table_pattern.allow) if self.table_pattern else ""

    def get_table_deny_pattern_string(self) -> str:
        return "|".join(self.table_pattern.deny) if self.table_pattern else ""

    def get_dataset_allow_pattern_string(self) -> str:
        return "|".join(self.dataset_pattern.allow) if self.table_pattern else ""

    def get_dataset_deny_pattern_string(self) -> str:
        return "|".join(self.dataset_pattern.deny) if self.table_pattern else ""
Example #13
class NimboConfig(pydantic.BaseModel):
    class Config:
        title = "Nimbo configuration"
        extra = "forbid"

    aws_profile: Optional[str] = None
    region_name: Optional[str] = None

    local_datasets_path: Optional[str] = None
    local_results_path: Optional[str] = None
    s3_datasets_path: Optional[str] = None
    s3_results_path: Optional[str] = None
    encryption: _Encryption = None

    instance_type: Optional[str] = None
    image: str = "ubuntu18-latest-drivers"
    disk_size: Optional[int] = None
    disk_iops: pydantic.conint(ge=0) = None
    disk_type: _DiskType = _DiskType.GP2
    spot: bool = False
    spot_duration: pydantic.conint(ge=60, le=360, multiple_of=60) = None
    security_group: Optional[str] = None
    instance_key: Optional[str] = None

    conda_env: Optional[str] = None
    run_in_background: bool = False
    persist: bool = False

    ssh_timeout: pydantic.conint(strict=True, ge=0) = 180
    telemetry: bool = True

    # The following are defined internally
    nimbo_config_file: str = NIMBO_CONFIG_FILE
    _nimbo_config_file_exists: bool = pydantic.PrivateAttr(
        default=os.path.isfile(NIMBO_CONFIG_FILE))
    user_id: Optional[str] = None
    telemetry_url: str = TELEMETRY_URL

    def get_session(self) -> boto3.Session:
        session = boto3.Session(profile_name=self.aws_profile,
                                region_name=self.region_name)

        self.user_id = session.client("sts").get_caller_identity()["Arn"]
        return session

    def assert_required_config_exists(self, *cases: RequiredCase) -> None:
        """ Designed to be used with the assert_required_config annotation """

        cases = RequiredCase.decompose(*cases)

        if len(cases) == 1 and RequiredCase.NONE in cases:
            return
        elif not self._nimbo_config_file_exists:
            raise FileNotFoundError(
                f"Nimbo configuration file '{self.nimbo_config_file}' not found.\n"
                "Run 'nimbo generate-config' to create the default config file."
            )

        required_config = {}

        if RequiredCase.MINIMAL in cases:
            required_config["aws_profile"] = self.aws_profile
            required_config["region_name"] = self.region_name
        if RequiredCase.STORAGE in cases:
            required_config["local_results_path"] = self.local_results_path
            required_config["local_datasets_path"] = self.local_datasets_path
            required_config["s3_results_path"] = self.s3_results_path
            required_config["s3_datasets_path"] = self.s3_datasets_path
        if RequiredCase.INSTANCE in cases:
            required_config["instance_type"] = self.instance_type
            required_config["disk_size"] = self.disk_size
            required_config["instance_key"] = self.instance_key
            required_config["security_group"] = self.security_group
        if RequiredCase.JOB in cases:
            required_config["conda_env"] = self.conda_env

        unspecified = [
            key for key, value in required_config.items() if not value
        ]
        if unspecified:
            raise AssertionError(
                f"For running this command {', '.join(unspecified)} should"
                f" be specified in {self.nimbo_config_file}")

        bad_fields = {}

        if RequiredCase.MINIMAL in cases:
            bad_fields["aws_profile"] = self._aws_profile_exists()
            bad_fields["region_name"] = self._region_name_valid()
        if RequiredCase.STORAGE in cases:
            bad_fields[
                "local_results_path"] = self._local_results_not_outside_project(
                )
            bad_fields[
                "local_datasets_path"] = self._local_datasets_not_outside_project(
                )
        if RequiredCase.INSTANCE in cases:
            bad_fields["instance_key"] = self._instance_key_valid()
            bad_fields["disk_iops"] = self._disk_iops_specified_when_needed()
        if RequiredCase.JOB in cases:
            bad_fields["conda_env"] = self._conda_env_valid()

        bad_fields = [(key, error) for key, error in bad_fields.items()
                      if error]

        if bad_fields:
            print(
                f"{len(bad_fields)} error{'' if len(bad_fields) == 1 else 's'} "
                f"in {NIMBO_CONFIG_FILE}\n")
            for key, error in bad_fields:
                print(key)
                print(f"  {error}")
            sys.exit(1)

    def _aws_profile_exists(self) -> Optional[str]:
        if self.aws_profile not in botocore.session.Session(
        ).available_profiles:
            return f"AWS Profile '{self.aws_profile}' could not be found"

    def _conda_env_valid(self) -> Optional[str]:
        if os.path.isabs(self.conda_env):
            return "conda_env should be a relative path"
        if ".." in self.conda_env:
            return "conda_env should not be outside of the project directory"
        if not os.path.isfile(self.conda_env):
            return f"file '{self.conda_env}' does not exist in the project directory"

    def _instance_key_valid(self) -> Optional[str]:
        if not os.path.isfile(self.instance_key):
            return f"file '{self.instance_key}' does not exist"

        permission = str(oct(os.stat(self.instance_key).st_mode))[-3:]
        if permission[1:] != "00":
            return (
                f"run 'chmod 400 {self.instance_key}' so that only you can read the key"
            )

    def _region_name_valid(self) -> Optional[str]:
        region_names = FULL_REGION_NAMES.keys()
        if self.region_name not in region_names:
            return (f"unknown region '{self.region_name}', "
                    f"expected to be one of {', '.join(region_names)}")

    def _local_results_not_outside_project(self) -> Optional[str]:
        if os.path.isabs(self.local_results_path):
            return "local_results_path should be a relative path"
        if ".." in self.local_results_path:
            return "local_results_path should not be outside of the project directory"

    def _local_datasets_not_outside_project(self) -> Optional[str]:
        if os.path.isabs(self.local_datasets_path):
            return "local_datasets_path should be a relative path"
        if ".." in self.local_datasets_path:
            return "local_datasets_path should not be outside of the project directory"

    def _disk_iops_specified_when_needed(self) -> Optional[str]:
        if self.disk_type in [_DiskType.IO1, _DiskType.IO2
                              ] and not self.disk_iops:
            return (
                "for disk types io1 or io2, the 'disk_iops' parameter has "
                "to be specified.\nPlease visit "
                "https://docs.nimbo.sh/nimbo-config-file-options for more details."
            )

    @pydantic.validator("nimbo_config_file")
    def _nimbo_config_file_unchanged(cls, value):
        if value != NIMBO_CONFIG_FILE:
            raise ValueError("overriding nimbo config file name is forbidden")
        return value

    @pydantic.validator("telemetry_url")
    def _nimbo_telemetry_url_unchanged(cls, value):
        if value != TELEMETRY_URL:
            raise ValueError("overriding telemetry url is forbidden")
        return value

    def save_initial_state(self) -> None:
        raise NotImplementedError(
            "save_initial_state is only available for NimboTestConfig")

    def reset_required_config(self) -> None:
        raise NotImplementedError(
            "reset_required_config is only available for NimboTestConfig")

    def inject_required_config(self, *cases: RequiredCase) -> None:
        raise NotImplementedError(
            "inject_required_config is only available for NimboTestConfig")
Example #14
class ServoRunner(pydantic.BaseModel, servo.logging.Mixin, servo.api.Mixin):
    interactive: bool = False
    _servo: servo.Servo = pydantic.PrivateAttr(None)
    _connected: bool = pydantic.PrivateAttr(False)
    _running: bool = pydantic.PrivateAttr(False)
    _main_loop_task: Optional[asyncio.Task] = pydantic.PrivateAttr(None)
    _diagnostics_loop_task: Optional[asyncio.Task] = pydantic.PrivateAttr(None)

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, servo_: servo.Servo, **kwargs) -> None:  # noqa: D10
        super().__init__(**kwargs)
        self._servo = servo_

        # initialize default servo options if not configured
        if self.config.settings is None:
            self.config.settings = servo.CommonConfiguration()

    @property
    def servo(self) -> servo.Servo:
        return self._servo

    @property
    def running(self) -> bool:
        return self._running

    @property
    def connected(self) -> bool:
        return self._connected

    @property
    def optimizer(self) -> servo.Optimizer:
        return self.servo.optimizer

    @property
    def config(self) -> servo.BaseServoConfiguration:
        return self.servo.config

    @property
    def api_client_options(self) -> dict[str, Any]:
        # Adopt the servo config for driving the API mixin
        return self.servo.api_client_options

    async def describe(self, control: Control) -> Description:
        self.logger.info("Describing...")

        aggregate_description = Description.construct()
        results: list[servo.EventResult] = await self.servo.dispatch_event(
            servo.Events.describe, control=control
        )
        for result in results:
            description = result.value
            aggregate_description.components.extend(description.components)
            aggregate_description.metrics.extend(description.metrics)

        return aggregate_description

    async def measure(self, param: servo.api.MeasureParams) -> Measurement:
        if isinstance(param, dict):
            # required parsing has failed in api.Mixin._post_event(), run parse_obj to surface the validation errors
            servo.api.MeasureParams.parse_obj(param)
        servo.logger.info(f"Measuring... [metrics={', '.join(param.metrics)}]")
        servo.logger.trace(devtools.pformat(param))

        aggregate_measurement = Measurement.construct()
        results: list[servo.EventResult] = await self.servo.dispatch_event(
            servo.Events.measure, metrics=param.metrics, control=param.control
        )
        for result in results:
            measurement = result.value
            aggregate_measurement.readings.extend(measurement.readings)
            aggregate_measurement.annotations.update(measurement.annotations)

        return aggregate_measurement

    async def adjust(
        self, adjustments: list[Adjustment], control: Control
    ) -> Description:
        summary = f"[{', '.join(list(map(str, adjustments)))}]"
        self.logger.info(f"Adjusting... {summary}")
        self.logger.trace(devtools.pformat(adjustments))
        self.logger.trace(devtools.pformat(control))

        aggregate_description = Description.construct()
        results = await self.servo.dispatch_event(
            servo.Events.adjust, adjustments=adjustments, control=control
        )
        for result in results:
            description = result.value
            aggregate_description.components.extend(description.components)
            aggregate_description.metrics.extend(description.metrics)

        self.logger.success(f"Adjustment completed {summary}")
        return aggregate_description

    async def exec_command(self) -> servo.api.Status:
        cmd_response = await self._post_event(servo.api.Events.whats_next, None)
        self.logger.info(f"What's Next? => {cmd_response.command}")
        self.logger.trace(devtools.pformat(cmd_response))

        if cmd_response.command == servo.api.Commands.describe:
            description = await self.describe(
                Control(**cmd_response.param.get("control", {}))
            )
            self.logger.success(
                f"Described: {len(description.components)} components, {len(description.metrics)} metrics"
            )
            self.logger.debug(devtools.pformat(description))

            status = servo.api.Status.ok(descriptor=description.__opsani_repr__())
            return await self._post_event(servo.api.Events.describe, status.dict())

        elif cmd_response.command == servo.api.Commands.measure:
            try:
                measurement = await self.measure(cmd_response.param)
                self.logger.success(
                    f"Measured: {len(measurement.readings)} readings, {len(measurement.annotations)} annotations"
                )
                self.logger.trace(devtools.pformat(measurement))
                param = measurement.__opsani_repr__()
            except servo.errors.EventError as error:
                self.logger.error(f"Measurement failed: {error}")
                param = servo.api.Status.from_error(error).dict()
                self.logger.error(f"Responding with {param}")
                self.logger.opt(exception=error).debug("Measure failure details")

            return await self._post_event(servo.api.Events.measure, param)

        elif cmd_response.command == servo.api.Commands.adjust:
            adjustments = servo.api.descriptor_to_adjustments(
                cmd_response.param["state"]
            )
            control = Control(**cmd_response.param.get("control", {}))

            try:
                description = await self.adjust(adjustments, control)
                status = servo.api.Status.ok(state=description.__opsani_repr__())

                components_count = len(description.components)
                settings_count = sum(
                    len(component.settings) for component in description.components
                )
                self.logger.success(
                    f"Adjusted: {components_count} components, {settings_count} settings"
                )
            except servo.EventError as error:
                self.logger.error(f"Adjustment failed: {error}")
                status = servo.api.Status.from_error(error)
                self.logger.error(f"Responding with {status.dict()}")
                self.logger.opt(exception=error).debug("Adjust failure details")

            return await self._post_event(servo.api.Events.adjust, status.dict())

        elif cmd_response.command == servo.api.Commands.sleep:
            # TODO: Model this
            duration = Duration(cmd_response.param.get("duration", 120))
            status = servo.utilities.key_paths.value_for_key_path(
                cmd_response.param, "data.status", None
            )
            reason = servo.utilities.key_paths.value_for_key_path(
                cmd_response.param, "data.reason", "unknown reason"
            )
            msg = f"{status}: {reason}" if status else f"{reason}"
            self.logger.info(f"Sleeping for {duration} ({msg}).")
            await asyncio.sleep(duration.total_seconds())

            # Return a status so we have a simple API contract
            return servo.api.Status(status="ok", message=msg)
        else:
            raise ValueError(f"Unknown command '{cmd_response.command.value}'")

    # Main run loop for processing commands from the optimizer
    async def main_loop(self) -> None:
        # FIXME: We have seen exceptions from using `with self.servo.current()` crossing contexts
        _set_current_servo(self.servo)

        while self._running:
            try:
                if self.interactive:
                    if not typer.confirm("Poll for next command?"):
                        typer.echo("Sleeping for 1m")
                        await asyncio.sleep(60)
                        continue

                status = await self.exec_command()
                if status.status == servo.api.OptimizerStatuses.unexpected_event:
                    self.logger.warning(
                        f"server reported unexpected event: {status.reason}"
                    )

            except (httpx.TimeoutException, httpx.HTTPStatusError) as error:
                self.logger.warning(
                    f"command execution failed HTTP client error: {error}"
                )

            except pydantic.ValidationError as error:
                self.logger.warning(
                    f"command execution failed with model validation error: {error}"
                )
                self.logger.opt(exception=error).debug(
                    "Pydantic model failed validation"
                )

            except Exception as error:
                self.logger.exception(f"failed with unrecoverable error: {error}")
                raise error

    def run_main_loop(self) -> None:
        if self._main_loop_task:
            self._main_loop_task.cancel()

        if self._diagnostics_loop_task:
            self._diagnostics_loop_task.cancel()

        def _reraise_if_necessary(task: asyncio.Task) -> None:
            try:
                if not task.cancelled():
                    task.result()
            except Exception as error:  # pylint: disable=broad-except
                self.logger.error(
                    f"Exiting from servo main loop do to error: {error} (task={task})"
                )
                self.logger.opt(exception=error).trace(
                    f"Exception raised by task {task}"
                )
                raise error  # Ensure that we surface the error for handling

        self._main_loop_task = asyncio.create_task(
            self.main_loop(), name=f"main loop for servo {self.optimizer.id}"
        )
        self._main_loop_task.add_done_callback(_reraise_if_necessary)

        if not servo.current_servo().config.no_diagnostics:
            diagnostics_handler = servo.telemetry.DiagnosticsHandler(self.servo)
            self._diagnostics_loop_task = asyncio.create_task(
                diagnostics_handler.diagnostics_check(),
                name=f"diagnostics for servo {self.optimizer.id}",
            )
        else:
            self.logger.info(
                f"Servo runner initialized with diagnostics polling disabled"
            )

    async def run(self, *, poll: bool = True) -> None:
        self._running = True

        _set_current_servo(self.servo)
        await self.servo.startup()
        self.logger.info(
            f"Servo started with {len(self.servo.connectors)} active connectors [{self.optimizer.id} @ {self.optimizer.url or self.optimizer.base_url}]"
        )

        async def giveup() -> None:
            loop = asyncio.get_event_loop()
            self.logger.critical("retries exhausted, giving up")
            asyncio.create_task(self.shutdown(loop))

        try:

            @backoff.on_exception(
                backoff.expo,
                httpx.HTTPError,
                max_time=lambda: self.config.settings.backoff.max_time(),
                max_tries=lambda: self.config.settings.backoff.max_tries(),
                on_giveup=giveup,
            )
            async def connect() -> None:
                self.logger.info("Saying HELLO.", end=" ")
                await self._post_event(
                    servo.api.Events.hello,
                    dict(
                        agent=servo.api.user_agent(),
                        telemetry=self.servo.telemetry.values,
                    ),
                )
                self._connected = True

            self.logger.info(
                f"Connecting to Opsani Optimizer @ {self.optimizer.url}..."
            )
            if self.interactive:
                typer.confirm("Connect to the optimizer?", abort=True)

            await connect()
        except typer.Abort:
            # Rescue abort and notify user
            servo.logger.warning("Operation aborted. Use Control-C to exit")
        except asyncio.CancelledError as error:
            self.logger.trace("task cancelled, aborting servo runner")
            raise error
        except:
            self.logger.exception("exception encountered during connect")

        if poll:
            self.run_main_loop()
        else:
            self.logger.warning(
                f"Servo runner initialized with polling disabled -- command loop is not running"
            )

    async def shutdown(self, *, reason: Optional[str] = None) -> None:
        """Shutdown the running servo."""
        try:
            self._running = False
            if self.connected:
                await self._post_event(servo.api.Events.goodbye, dict(reason=reason))
        except Exception:
            self.logger.exception(f"Exception occurred during GOODBYE request")
Example #15
class EventProgress(BaseProgress):
    """EventProgress objects track progress against an indeterminate event."""

    timeout: Optional[Duration] = None
    """The maximum amount of time to wait for the event to be triggered.

    When None, the event will be awaited forever.
    """

    settlement: Optional[Duration] = None
    """The amount of time to wait for progress to be reset following an event trigger before returning early.

    When None, progress is returned immediately upon the event being triggered.
    """

    _event: asyncio.Event = pydantic.PrivateAttr(default_factory=asyncio.Event)
    _settlement_timer: Optional[asyncio.TimerHandle] = pydantic.PrivateAttr(
        None)
    _settlement_started_at: Optional[datetime.datetime] = pydantic.PrivateAttr(
        None)

    def __init__(
        self,
        timeout: Optional["Duration"] = None,
        settlement: Optional["Duration"] = None,
        **kwargs,
    ) -> None:  # noqa: D107
        super().__init__(timeout=timeout, settlement=settlement, **kwargs)

    def complete(self) -> None:
        """Advance progress immediately to completion.

        This method does not respect settlement time. Typical operation should utilize the `trigger`
        method.
        """
        self._event.set()

    @property
    def completed(self) -> bool:
        """Return True if the progress has been completed."""
        return self._event.is_set()

    @property
    def timed_out(self) -> bool:
        """Return True if the timeout has elapsed.

        Return False if there is no timeout configured or the progress has not been started.
        """
        if self.timeout == 0 and self.started:
            return True
        if not self.timeout or not self.started:
            return False
        return Duration.since(self.started_at) >= self.timeout

    @property
    def finished(self) -> bool:
        return self.timed_out or super().finished

    def trigger(self) -> None:
        """Trigger the event to advance progress toward completion.

        When the event is triggered, the behavior is dependent upon whether or not a
        settlement duration is configured. When None, progress is immediately advanced to 100%
        and progress is finished, notifying all observers.

        When a settlement duration is configured, progress will begin advancing across the settlement
        duration to allow for the progress to be reset.
        """
        if self.settlement:
            self._settlement_started_at = datetime.datetime.now()
            self._settlement_timer = asyncio.get_event_loop().call_later(
                self.settlement.total_seconds(), self.complete)
        else:
            self.complete()

    def reset(self) -> None:
        """Reset progress to zero by clearing the event trigger.

        Resetting progress does not affect the timeout which will eventually finalize progress
        when elapsed.
        """
        if self._settlement_timer:
            self._settlement_timer.cancel()
        self._settlement_started_at = None
        self._event.clear()

    async def wait(self) -> None:
        """Asynchronously wait until the event condition has been triggered.

        If the progress was initialized with a timeout, raises a TimeoutError when the timeout is
        elapsed.

        Raises:
            TimeoutError: Raised if the timeout elapses before the event is triggered.
        """
        timeout = self.timeout.total_seconds() if self.timeout else None
        await asyncio.wait_for(self._event.wait(), timeout=timeout)

    @property
    def settling(self) -> bool:
        """Return True if the progress has been triggered but is awaiting settlement before completion."""
        return self._settlement_started_at is not None

    @property
    def settlement_remaining(self) -> Optional[Duration]:
        """Return the amount of settlement time remaining before completion."""
        if self.settling:
            duration = Duration(self.settlement -
                                Duration.since(self._settlement_started_at))
            return duration if duration.total_seconds() >= 0 else None
        else:
            return None

    @property
    def progress(self) -> float:
        """Return completion progress percentage as a floating point value from 0.0 to 100.0

        If the event has been triggered, immediately returns 100.0.
        When progress has started but has not yet completed, the behavior is conditional upon
        the configuration of a timeout and/or settlement time.

        When settlement is in effect, progress is relative to the amount of time remaining in the
        settlement duration. This can result in progress that goes backward as the finish moves
        forward based on the event condition being triggered.
        """
        if self._event.is_set():
            return 100.0
        elif self.started:
            if self.settling:
                return min(
                    100.0,
                    100.0 * (Duration.since(self._settlement_started_at) /
                             self.settlement),
                )
            elif self.timeout:
                return min(100.0, 100.0 * (self.elapsed / self.timeout))

        # NOTE: Without a timeout or settlement duration we advance from 0 to 100. Like a true gangsta
        return 0.0

    async def watch(
        self,
        notify: Callable[["DurationProgress"], Union[None, Awaitable[None]]],
        every: Optional[Duration] = None,
    ) -> None:
        # NOTE: Handle the case where reporting interval < timeout (matters mostly for tests)
        if every is None:
            if self.timeout is None:
                every = Duration("5s")
            else:
                every = min(Duration("5s"), self.timeout)

        return await super().watch(notify, every)
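
The private asyncio.Event created via default_factory=asyncio.Event gives each progress object its own trigger without exposing it as a model field. A stripped-down, self-contained sketch of the trigger/wait mechanics, leaving out the timeout and settlement handling of the class above:

# Minimal sketch of the PrivateAttr + asyncio.Event pattern (not the full EventProgress).
import asyncio

import pydantic


class TriggerProgress(pydantic.BaseModel):
    _event: asyncio.Event = pydantic.PrivateAttr(default_factory=asyncio.Event)

    def trigger(self) -> None:
        self._event.set()

    @property
    def completed(self) -> bool:
        return self._event.is_set()

    async def wait(self) -> None:
        await self._event.wait()


async def main() -> None:
    progress = TriggerProgress()
    asyncio.get_running_loop().call_later(0.05, progress.trigger)
    await progress.wait()
    assert progress.completed


asyncio.run(main())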
Example #16
class BigQueryConfig(BaseTimeWindowConfig, SQLAlchemyConfig):
    scheme: str = "bigquery"
    project_id: Optional[str] = pydantic.Field(
        default=None,
        description=
        "Project ID to ingest from. If not specified, will infer from environment.",
    )
    lineage_client_project_id: Optional[str] = pydantic.Field(
        default=None,
        description=
        "If you want to use a different ProjectId for the lineage collection you can set it here.",
    )
    log_page_size: pydantic.PositiveInt = pydantic.Field(
        default=1000,
        description=
        "The number of log item will be queried per page for lineage collection",
    )
    credential: Optional[BigQueryCredential] = pydantic.Field(
        description="BigQuery credential informations")
    # extra_client_options, include_table_lineage and max_query_duration are relevant only when computing the lineage.
    extra_client_options: Dict[str, Any] = pydantic.Field(
        default={},
        description=
        "Additional options to pass to google.cloud.logging_v2.client.Client.",
    )
    include_table_lineage: Optional[bool] = pydantic.Field(
        default=True,
        description=
        "Option to enable/disable lineage generation. Is enabled by default.",
    )
    max_query_duration: timedelta = pydantic.Field(
        default=timedelta(minutes=15),
        description=
        "Correction to pad start_time and end_time with. For handling the case where the read happens within our time range but the query completion event is delayed and happens after the configured end time.",
    )
    bigquery_audit_metadata_datasets: Optional[List[str]] = pydantic.Field(
        default=None,
        description=
        "A list of datasets that contain a table named cloudaudit_googleapis_com_data_access which contain BigQuery audit logs, specifically, those containing BigQueryAuditMetadata. It is recommended that the project of the dataset is also specified, for example, projectA.datasetB.",
    )
    use_exported_bigquery_audit_metadata: bool = pydantic.Field(
        default=False,
        description=
        "When configured, use BigQueryAuditMetadata in bigquery_audit_metadata_datasets to compute lineage information.",
    )
    use_date_sharded_audit_log_tables: bool = pydantic.Field(
        default=False,
        description=
        "Whether to read date sharded tables or time partitioned tables when extracting usage from exported audit logs.",
    )
    _credentials_path: Optional[str] = pydantic.PrivateAttr(None)
    temp_table_dataset_prefix: str = pydantic.Field(
        default="_",
        description=
        "If you are creating temp tables in a dataset with a particular prefix you can use this config to set the prefix for the dataset. This is to support workflows from before bigquery's introduction of temp tables. By default we use `_` because of datasets that begin with an underscore are hidden by default https://cloud.google.com/bigquery/docs/datasets#dataset-naming.",
    )
    use_v2_audit_metadata: Optional[bool] = pydantic.Field(
        default=False,
        description="Whether to ingest logs using the v2 format.")
    upstream_lineage_in_report: bool = pydantic.Field(
        default=False,
        description=
        "Useful for debugging lineage information. Set to True to see the raw lineage created internally.",
    )

    def __init__(self, **data: Any):
        super().__init__(**data)

        if self.credential:
            self._credentials_path = self.credential.create_credential_temp_file(
            )
            logger.debug(
                f"Creating temporary credential file at {self._credentials_path}"
            )
            os.environ[
                "GOOGLE_APPLICATION_CREDENTIALS"] = self._credentials_path

    def get_sql_alchemy_url(self):
        if self.project_id:
            return f"{self.scheme}://{self.project_id}"
        # When project_id is not set, we will attempt to detect the project ID
        # based on the credentials or environment variables.
        # See https://github.com/mxmzdlv/pybigquery#authentication.
        return f"{self.scheme}://"

    @pydantic.validator("platform_instance")
    def bigquery_doesnt_need_platform_instance(cls, v):
        if v is not None:
            raise ConfigurationError(
                "BigQuery project ids are globally unique. You do not need to specify a platform instance."
            )

    @pydantic.root_validator()
    def validate_that_bigquery_audit_metadata_datasets_is_correctly_configured(
            cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if (values.get("use_exported_bigquery_audit_metadata")
                and not values.get("use_v2_audit_metadata")
                and not values.get("bigquery_audit_metadata_datasets")):
            raise ConfigurationError(
                "bigquery_audit_metadata_datasets must be specified if using exported audit metadata. Otherwise set use_v2_audit_metadata to True."
            )
        return values

    @pydantic.validator("platform")
    def platform_is_always_bigquery(cls, v):
        return "bigquery"
Example #17
class Servo(servo.connector.BaseConnector):
    """A connector that interacts with the Opsani API to perform optimization.

    The `Servo` is a core object of the `servo` package. It manages a set of
    connectors that provide integration and interactivity to external services
    such as metrics collectors, orchestration systems, load generators, etc. The
    Servo acts primarily as an event gateway between the Opsani API and its child
    connectors.

    Servo objects are configured with a dynamically created class that is built by
    the `servo.Assembly` class. Servo objects are typically not created directly
    and are instead built through the `Assembly.assemble` method.

    Attributes:
        connectors...
        config...
    """

    config: servo.configuration.BaseServoConfiguration
    """Configuration of the Servo assembly.

    Note that the configuration is built dynamically at Servo assembly time.
    The concrete type is created in `Assembly.assemble()` and adds a field for each active
    connector.
    """

    connectors: list[servo.connector.BaseConnector]
    """The active connectors in the Servo.
    """

    _running: bool = pydantic.PrivateAttr(False)

    async def dispatch_event(
        self, *args, **kwargs
    ) -> Union[Optional[servo.events.EventResult],
               list[servo.events.EventResult]]:
        with self.current():
            return await super().dispatch_event(*args, **kwargs)

    def __init__(self, *args, connectors: list[servo.connector.BaseConnector],
                 **kwargs) -> None:  # noqa: D107
        super().__init__(*args, connectors=[], **kwargs)

        # Ensure the connectors refer to the same objects by identity (required for eventing)
        self.connectors.extend(connectors)

        # associate shared config with our children
        for connector in connectors + [self]:
            connector._global_config = self.config.settings

    @pydantic.root_validator()
    def _initialize_name(cls, values: dict[str, Any]) -> dict[str, Any]:
        if values["name"] == "servo" and values.get("config"):
            values["name"] = values["config"].name or getattr(
                values["config"].optimizer, "id", "servo")

        return values

    async def attach(self, servo_: servo.assembly.Assembly) -> None:
        """Notify the servo that it has been attached to an Assembly."""
        await self.dispatch_event(Events.attach, self)

    async def detach(self, servo_: servo.assembly.Assembly) -> None:
        """Notify the servo that it has been detached from an Assembly."""
        await self.dispatch_event(Events.detach, self)

    @property
    def is_running(self) -> bool:
        """Return True if the servo is running."""
        return self._running

    async def startup(self) -> None:
        """Notify all active connectors that the servo is starting up."""
        if self.is_running:
            raise RuntimeError(
                "Cannot start up a servo that is already running")

        self._running = True

        await self.dispatch_event(Events.startup,
                                  _prepositions=servo.events.Preposition.on)

        # Start up the pub/sub exchange
        if not self.pubsub_exchange.running:
            self.pubsub_exchange.start()

    async def shutdown(self) -> None:
        """Notify all active connectors that the servo is shutting down."""
        if not self.is_running:
            raise RuntimeError("Cannot shut down a servo that is not running")

        # Remove all the connectors (dispatches shutdown event)
        await asyncio.gather(*list(map(self.remove_connector, self.connectors))
                             )

        # Shut down the pub/sub exchange
        if self.pubsub_exchange.running:
            await self.pubsub_exchange.shutdown()

        self._running = False

    @property
    def all_connectors(self) -> list[servo.connector.BaseConnector]:
        """Return a list of all active connectors including the Servo."""
        return [self, *self.connectors]

    def connectors_named(self,
                         names: Sequence[str]) -> list[servo.BaseConnector]:
        return [
            connector for connector in self.all_connectors
            if connector.name in names
        ]

    def get_connector(
        self, name: Union[str, Sequence[str]]
    ) -> Optional[Union[servo.connector.BaseConnector,
                        list[servo.connector.BaseConnector]]]:
        """Return one or more connectors by name.

        This is a convenience method equivalent to iterating `connectors` and comparing by name.

        When given a single name, returns the connector or `None` if not found.
        When given a sequence of names, returns a list of Connectors for all connectors found.
        """
        if isinstance(name, str):
            return next(iter(self.connectors_named([name])), None)
        else:
            return self.connectors_named(name)

    async def add_connector(self, name: str,
                            connector: servo.connector.BaseConnector) -> None:
        """Add a connector to the servo.

        The connector is added to the servo event bus and is initialized with
        the `attach` event to prepare for execution. If the servo is currently
        running, the connector is sent the `startup` event as well.

        Args:
            name: A unique name for the connector in the servo.
            connector: The connector to be added to the servo.

        Raises:
            ValueError: Raised if the name is not unique in the servo.
        """
        if self.get_connector(name):
            raise ValueError(
                f"invalid name: a connector named '{name}' already exists in the servo"
            )

        connector.name = name
        connector._global_config = self.config.settings

        # Add to the event bus
        self.connectors.append(connector)
        self.__connectors__.append(connector)

        # Add to the pub/sub exchange
        connector.pubsub_exchange = self.pubsub_exchange

        # Register our name into the config class
        with servo.utilities.pydantic.extra(self.config):
            setattr(self.config, name, connector.config)

        await self.dispatch_event(Events.attach, self, include=[connector])

        # Start the connector if we are running
        if self.is_running:
            await self.dispatch_event(
                Events.startup,
                include=[connector],
                _prepositions=servo.events.Preposition.on,
            )

    async def remove_connector(
            self, connector: Union[str,
                                   servo.connector.BaseConnector]) -> None:
        """Remove a connector from the servo.

        The connector is removed from the servo event bus and is finalized with
        the detach event to prepare for eviction. If the servo is currently running,
        the connector is sent the shutdown event as well.

        Args:
            connector: The connector or name to remove from the servo.

        Raises:
            ValueError: Raised if the connector does not exist in the servo.
        """
        connector_ = (connector
                      if isinstance(connector, servo.connector.BaseConnector)
                      else self.get_connector(connector))
        if connector_ not in self.connectors:
            name = connector_.name if connector_ else connector
            raise ValueError(
                f"invalid connector: a connector named '{name}' does not exist in the servo"
            )

        # Shut the connector down if we are running
        if self.is_running:
            await self.dispatch_event(
                Events.shutdown,
                include=[connector_],
                _prepositions=servo.events.Preposition.on,
            )

        await self.dispatch_event(Events.detach, self, include=[connector_])

        # Remove from the event bus
        self.connectors.remove(connector_)
        self.__connectors__.remove(connector_)

        # Remove from the pub/sub exchange
        connector_.cancel_subscribers()
        connector_.cancel_publishers()
        connector_.pubsub_exchange = servo.pubsub.Exchange()

        with servo.utilities.pydantic.extra(self.config):
            delattr(self.config, connector_.name)

    def top_level_schema(self, *, all: bool = False) -> dict[str, Any]:
        """Return a schema that only includes connector model definitions"""
        connectors = servo.Assembly.all_connector_types(
        ) if all else self.connectors
        config_models = list(map(lambda c: c.config_model(), connectors))
        return pydantic.schema.schema(config_models, title="Servo Schema")

    def top_level_schema_json(self, *, all: bool = False) -> str:
        """Return a JSON string representation of the top level schema"""
        return json.dumps(
            self.top_level_schema(all=all),
            indent=2,
            default=pydantic.json.pydantic_encoder,
        )

    async def check_servo(
        self, print_callback: Optional[Callable[[str], None]] = None
    ) -> bool:
        """Run the check event across the active connectors, returning True once all checks pass."""

        connectors = self.config.checks.connectors
        name = self.config.checks.name
        id = self.config.checks.id
        tag = self.config.checks.tag

        quiet = self.config.checks.quiet
        progressive = self.config.checks.progressive
        wait = self.config.checks.wait
        delay = self.config.checks.delay
        halt_on = self.config.checks.halt_on

        # Validate that explicit args support check events
        connector_objs = (
            self.connectors_named(connectors) if connectors else list(
                filter(
                    lambda c: c.responds_to_event(servo.Events.check),
                    self.all_connectors,
                )))
        if not connector_objs:
            if connectors:
                raise servo.ConnectorNotFoundError(
                    f"no connector found with name(s) '{connectors}'")
            else:
                raise servo.EventHandlersNotFoundError(
                    f"no currently assembled connectors respond to the check event"
                )
        validate_connectors_respond_to_event(connector_objs,
                                             servo.Events.check)

        if wait:
            summary = "Running checks"
            summary += " progressively" if progressive else ""
            summary += f" for up to {wait} with a delay of {delay} between iterations"
            servo.logger.info(summary)

        passing = set()
        progress = servo.DurationProgress(servo.Duration(wait or 0))
        ready = False

        while not progress.finished:
            if not progress.started:
                # run at least one time
                progress.start()

            args = dict(
                name=servo.utilities.parse_re(name),
                id=servo.utilities.parse_id(id),
                tags=servo.utilities.parse_csv(tag),
            )
            constraints = dict(filter(lambda i: bool(i[1]), args.items()))

            results: List[servo.EventResult] = (await self.dispatch_event(
                servo.Events.check,
                servo.CheckFilter(**constraints),
                include=self.all_connectors,
                halt_on=halt_on,
            ) or [])

            ready = await servo.checks.CheckHelpers.process_checks(
                checks_config=self.config.checks,
                results=results,
                passing=passing,
            )
            if not progressive and not quiet:
                output = await servo.checks.CheckHelpers.checks_to_table(
                    checks_config=self.config.checks, results=results)
                print_callback(output)

            if ready:
                return ready
            else:
                if wait and delay is not None:
                    servo.logger.info(
                        f"waiting for {delay} before rerunning failing checks")
                    await asyncio.sleep(servo.Duration(delay).total_seconds())

                if progress.finished:
                    # Don't log a timeout if we aren't running in wait mode
                    if progress.duration:
                        servo.logger.error(
                            f"timed out waiting for checks to pass {progress.duration}"
                        )
                    return ready

    ##
    # Event handlers

    @servo.events.on_event()
    async def check(
        self,
        matching: Optional[servo.checks.CheckFilter],
        halt_on: Optional[
            servo.types.ErrorSeverity] = servo.types.ErrorSeverity.critical,
    ) -> list[servo.checks.Check]:
        """Check that the servo is ready to perform optimization.

        Args:
            matching: An optional filter to limit the checks that are executed.
            halt_on: The severity level of errors that should halt execution of checks.

        Returns:
            A list of check objects that describe the outcomes of the checks that were run.
        """
        try:
            async with self.api_client() as client:
                event_request = servo.api.Request(event=servo.api.Events.hello)
                response = await client.post("servo",
                                             data=event_request.json())
                success = response.status_code == httpx.codes.OK
                return [
                    servo.checks.Check(
                        name="Opsani API connectivity",
                        success=success,
                        message=f"Response status code: {response.status_code}",
                    )
                ]
        except Exception as error:
            return [
                servo.checks.Check(
                    name="Opsani API connectivity",
                    success=False,
                    message=str(error),
                )
            ]

    @contextlib.contextmanager
    def current(self):
        """A context manager that sets the current servo context."""
        try:
            token = _current_context_var.set(self)
            yield self

        finally:
            _current_context_var.reset(token)
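
For reference, here is a minimal, self-contained sketch of the contextvar-based current() pattern used in the context manager above. Widget, as_current, and _current_var are illustrative names, not part of the servo API.

# Minimal sketch of the contextvar-backed "current instance" pattern (illustrative names only).
import contextlib
import contextvars
from typing import Optional

_current_var = contextvars.ContextVar("current_widget", default=None)


class Widget:
    def __init__(self, name: str) -> None:
        self.name = name

    @classmethod
    def current(cls) -> Optional["Widget"]:
        # Return the instance active in this execution context, if any
        return _current_var.get()

    @contextlib.contextmanager
    def as_current(self):
        # Set this instance as current and restore the previous value on exit
        token = _current_var.set(self)
        try:
            yield self
        finally:
            _current_var.reset(token)


with Widget("primary").as_current() as widget:
    assert Widget.current() is widget
assert Widget.current() is None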
Example #18
class VegetaConfiguration(servo.BaseConfiguration):
    """
    Configuration of the Vegeta connector
    """

    rate: str = pydantic.Field(
        description=
        "Specifies the request rate per time unit to issue against the targets. Given in the format of request/time unit.",
    )
    format: TargetFormat = pydantic.Field(
        TargetFormat.http,
        description=
        "Specifies the format of the targets input. Valid values are http and json. Refer to the Vegeta docs for details.",
    )
    target: Optional[str] = pydantic.Field(
        description=
        "Specifies a single formatted Vegeta target to load. See the format option to learn about available target formats. This option is exclusive of the targets option and will provide a target to Vegeta via stdin."
    )
    targets: Optional[pydantic.FilePath] = pydantic.Field(
        description=
        "Specifies the file from which to read targets. See the format option to learn about available target formats. This option is exclusive of the target option and will provide targets to via through a file on disk."
    )
    connections: int = pydantic.Field(
        10000,
        description=
        "Specifies the maximum number of idle open connections per target host.",
    )
    workers: int = pydantic.Field(
        10,
        description=
        "Specifies the initial number of workers used in the attack. The workers will automatically increase to achieve the target request rate, up to max-workers.",
    )
    max_workers: Optional[int] = pydantic.Field(
        None,
        description=
        "The maximum number of workers used to sustain the attack. This can be used to control the concurrency of the attack to simulate a target number of clients.",
    )
    max_body: int = pydantic.Field(
        -1,
        description=
        "Specifies the maximum number of bytes to capture from the body of each response. Remaining unread bytes will be fully read but discarded.",
    )
    http2: bool = pydantic.Field(
        True,
        description=
        "Specifies whether to enable HTTP/2 requests to servers which support it.",
    )
    keepalive: bool = pydantic.Field(
        True,
        description=
        "Specifies whether to reuse TCP connections between HTTP requests.",
    )
    insecure: bool = pydantic.Field(
        False,
        description=
        "Specifies whether to ignore invalid server TLS certificates.",
    )
    reporting_interval: servo.Duration = pydantic.Field(
        "15s",
        description="How often to report metrics during a measurement cycle.",
    )
    _duration: servo.Duration = pydantic.PrivateAttr(None)

    @property
    def duration(self) -> Optional[servo.Duration]:
        if self._duration is not None:
            return servo.Duration.validate(self._duration)
        else:
            return None

    @pydantic.root_validator(pre=True)
    @classmethod
    def validate_target(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        target, targets = servo.values_for_keys(values, ("target", "targets"))
        if target is None and targets is None:
            raise ValueError("target or targets must be configured")

        if target and targets:
            raise ValueError("target and targets cannot both be configured")

        return values

    @staticmethod
    def target_json_schema() -> Dict[str, Any]:
        """
        Returns the parsed JSON Schema for validating Vegeta targets in the JSON format.
        """
        schema_path = pathlib.Path(
            __file__).parent / "vegeta_target_schema.json"
        return json.load(open(schema_path))

    @pydantic.validator("target", "targets")
    @classmethod
    def validate_target_format(
        cls,
        value: Union[str, pydantic.FilePath],
        field: pydantic.Field,
        values: Dict[str, Any],
    ) -> str:
        if value is None:
            return value

        format: TargetFormat = values.get("format")
        if field.name == "target":
            value_stream = io.StringIO(value)
        elif field.name == "targets":
            value_stream = open(value)
        else:
            raise ValueError(f"unknown field '{field.name}'")

        if format == TargetFormat.http:
            # Scan through the targets and run basic heuristics
            # We don't validate ordering to avoid building a full parser
            count = 0
            for line in value_stream:
                count = count + 1
                line = line.strip()
                if len(line) == 0 or line[0] in ("#", "@"):
                    continue

                maybe_method_and_url = line.split(" ", 2)
                if (len(maybe_method_and_url) == 2
                        and maybe_method_and_url[0] in servo.HTTP_METHODS):
                    if re.match("https?://*", maybe_method_and_url[1]):
                        continue

                maybe_header_and_value = line.split(":", 2)
                if len(maybe_header_and_value
                       ) == 2 and maybe_header_and_value[1]:
                    continue

                raise ValueError(f"invalid target: {line}")

            if count == 0:
                raise ValueError(f"no targets found")

        elif format == TargetFormat.json:
            try:
                data = json.load(value_stream)
            except json.JSONDecodeError as e:
                raise ValueError(f"{field.name} contains invalid JSON") from e

            # Validate the target data with JSON Schema
            try:
                jsonschema.validate(instance=data,
                                    schema=cls.target_json_schema())
            except jsonschema.ValidationError as error:
                raise ValueError(
                    f"Invalid Vegeta JSON target: {error.message}") from error

        return value

    @pydantic.validator("rate")
    @classmethod
    def validate_rate(cls, v: Union[int, str]) -> str:
        assert isinstance(
            v,
            (int, str)), "rate must be an integer or a rate descriptor string"

        # Integer rates
        if isinstance(v, int) or v.isdigit():
            return str(v)

        # Check for hits/interval
        components = v.split("/")
        assert len(
            components) == 2, "rate strings are of the form hits/interval"

        hits = components[0]
        duration = components[1]
        assert hits.isnumeric(), "rate must have an integer hits component"

        # Try to parse it from Golang duration string
        try:
            servo.Duration(duration)
        except ValueError as e:
            raise ValueError(
                f"Invalid duration '{duration}' in rate '{v}'") from e

        return v

    @classmethod
    def generate(cls, **kwargs) -> "VegetaConfiguration":
        return cls(
            rate="50/1s",
            target="GET https://example.com/",
            description=
            "Update the rate and target/targets to match your load profile",
            **kwargs,
        )

    class Config:
        json_encoders = servo.BaseConfiguration.json_encoders(
            {TargetFormat: lambda t: t.value})
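
A standalone sketch of the hits/interval rate validation performed by validate_rate above. servo.Duration is replaced here with a simplified regex for Golang-style durations, an assumption made only for illustration.

# Simplified rate validation: integer rates or "hits/interval" strings.
import re

_GO_DURATION = re.compile(r"^\d+(\.\d+)?(ns|us|ms|s|m|h)$")


def validate_rate(v) -> str:
    if not isinstance(v, (int, str)):
        raise ValueError("rate must be an integer or a rate descriptor string")

    # Integer rates are passed through as strings
    if isinstance(v, int) or v.isdigit():
        return str(v)

    hits, sep, duration = v.partition("/")
    if not sep:
        raise ValueError("rate strings are of the form hits/interval")
    if not hits.isnumeric():
        raise ValueError("rate must have an integer hits component")
    if not _GO_DURATION.match(duration):
        raise ValueError(f"invalid duration '{duration}' in rate '{v}'")
    return v


assert validate_rate("50/1s") == "50/1s"
assert validate_rate(3000) == "3000"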
Example #19
class GcpTestConfig(CommonTestConfigMixin, GcpConfig):
    _initial_state: Dict[str, Any] = pydantic.PrivateAttr(default_factory=dict)
    _nimbo_config_file_exists: bool = pydantic.PrivateAttr(default=True)

    def inject_required_config(self, *cases: RequiredCase) -> None:
        pass
Example #20
class FastFailObserver(pydantic.BaseModel):
    config: servo.configuration.FastFailConfiguration
    input: servo.types.SloInput
    metrics_getter: Callable[[datetime.datetime, datetime.datetime],
                             Awaitable[Dict[str, List[servo.types.Reading]]], ]

    _results: Dict[servo.types.SloCondition,
                   List[SloOutcome]] = pydantic.PrivateAttr(
                       default=collections.defaultdict(list))

    async def observe(self, progress: servo.EventProgress) -> None:
        if progress.elapsed < self.config.skip:
            return

        checked_at = datetime.datetime.now()
        metrics = await self.metrics_getter(checked_at - self.config.span,
                                            checked_at)
        self.check_readings(metrics=metrics, checked_at=checked_at)

    def check_readings(
        self,
        metrics: Dict[str, List[servo.types.Reading]],
        checked_at: datetime.datetime,
    ) -> None:
        failures: Dict[servo.types.SloCondition, List[SloOutcome]] = {}
        for condition in self.input.conditions:
            result_args = dict(checked_at=checked_at)
            # Evaluate target metric
            metric_readings = metrics.get(condition.metric)
            if not metric_readings:
                self._results[condition].append(
                    SloOutcome(**result_args,
                               status=SloOutcomeStatus.missing_metric))
                continue

            metric_value = _get_scalar_from_readings(metric_readings)
            result_args.update(metric_value=metric_value,
                               metric_readings=metric_readings)

            if self.config.treat_zero_as_missing and float(metric_value) == 0:
                self._results[condition].append(
                    SloOutcome(**result_args,
                               status=SloOutcomeStatus.missing_metric))
                continue

            # Evaluate threshold
            threshold_readings = None
            if condition.threshold is not None:
                threshold_value = condition.threshold * condition.threshold_multiplier

                result_args.update(
                    threshold_value=threshold_value,
                    threshold_readings=threshold_readings,
                )
            elif condition.threshold_metric is not None:
                threshold_readings = metrics.get(condition.threshold_metric)
                if not threshold_readings:
                    self._results[condition].append(
                        SloOutcome(**result_args,
                                   status=SloOutcomeStatus.missing_threshold))
                    continue

                threshold_scalar = _get_scalar_from_readings(
                    threshold_readings)
                threshold_value = threshold_scalar * condition.threshold_multiplier

                result_args.update(
                    threshold_value=threshold_value,
                    threshold_readings=threshold_readings,
                )

                if self.config.treat_zero_as_missing and float(
                        threshold_value) == 0:
                    self._results[condition].append(
                        SloOutcome(**result_args,
                                   status=SloOutcomeStatus.missing_threshold))
                    continue

                elif 0 <= metric_value <= condition.slo_metric_minimum:
                    self._results[condition].append(
                        SloOutcome(**result_args,
                                   status=SloOutcomeStatus.zero_metric))
                    continue

                elif 0 <= threshold_value <= condition.slo_threshold_minimum:
                    self._results[condition].append(
                        SloOutcome(**result_args,
                                   status=SloOutcomeStatus.zero_threshold))
                    continue

            if metric_value.is_nan() or threshold_value.is_nan():
                self._results[condition].append(
                    SloOutcome(**result_args,
                               status=SloOutcomeStatus.missing_threshold))
                continue

            # Check target against threshold
            check_passed_op = _get_keep_operator(condition.keep)
            if check_passed_op(metric_value, threshold_value):
                self._results[condition].append(
                    SloOutcome(**result_args, status=SloOutcomeStatus.passed))
            else:
                self._results[condition].append(
                    SloOutcome(**result_args, status=SloOutcomeStatus.failed))

            # Update window by slicing last n items from list where n is trigger_window
            self._results[condition] = self._results[condition][
                -condition.trigger_window:]

            if (len(
                    list(
                        filter(
                            lambda res: res.status == SloOutcomeStatus.failed,
                            self._results[condition],
                        ))) >= condition.trigger_count):
                failures[condition] = self._results[condition]

        servo.logger.debug(f"SLO results: {devtools.pformat(self._results)}")

        # Log the latest results
        last_results_buckets: Dict[SloOutcomeStatus,
                                   List[str]] = collections.defaultdict(list)
        for condition, results_list in self._results.items():
            last_result = results_list[-1]
            last_results_buckets[last_result.status].append(str(condition))

        last_results_messages: List[str] = []
        for status, condition_str_list in last_results_buckets.items():
            last_results_messages.append(
                f"x{len(condition_str_list)} {status} [{', '.join(condition_str_list)}]"
            )

        servo.logger.info(
            f"SLO statuses from last check: {', '.join(last_results_messages)}"
        )

        if failures:
            raise servo.errors.EventAbortedError(
                f"SLO violation(s) observed: {_get_results_str(failures)}",
                reason=SLO_FAILED_REASON,
            )
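
A standalone sketch of the sliding-window trigger logic used in check_readings above: keep only the last trigger_window outcomes per condition and report a violation once trigger_count of them have failed. The helper and condition names are illustrative.

# Sliding-window failure trigger (illustrative names, plain data structures).
from typing import Dict, List


def record_outcome(results: Dict[str, List[str]],
                   condition: str,
                   status: str,
                   *,
                   trigger_window: int = 5,
                   trigger_count: int = 3) -> bool:
    """Return True when the condition has accumulated enough failures to trip."""
    results.setdefault(condition, []).append(status)
    # Slice to retain only the most recent trigger_window outcomes
    results[condition] = results[condition][-trigger_window:]
    failures = [s for s in results[condition] if s == "failed"]
    return len(failures) >= trigger_count


results: Dict[str, List[str]] = {}
for status in ("passed", "failed", "failed", "passed", "failed"):
    tripped = record_outcome(results, "p95_latency <= 2 * threshold", status)
print(tripped)  # True: three failures within the five-outcome window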
Example #21
class BaseConnector(
        servo.utilities.associations.Mixin,
        servo.api.Mixin,
        servo.events.Mixin,
        servo.logging.Mixin,
        servo.pubsub.Mixin,
        servo.repeating.Mixin,
        pydantic.BaseModel,
        abc.ABC,
        metaclass=servo.events.Metaclass,
):
    """Connectors expose functionality to Servo assemblies by connecting external services and resources."""

    ##
    # Connector metadata

    name: str = None
    """Name of the connector, by default derived from the class name.
    """

    full_name: ClassVar[str] = None
    """The full name of the connector for referencing it unambiguously.
    """

    version: ClassVar[Version] = None
    """Semantic Versioning string of the connector.
    """

    cryptonym: ClassVar[Optional[str]] = None
    """Optional code name of the version.
    """

    description: ClassVar[Optional[str]] = None
    """Optional textual description of the connector.
    """

    homepage: ClassVar[Optional[pydantic.HttpUrl]] = None
    """Link to the homepage of the connector.
    """

    license: ClassVar[Optional[License]] = None
    """An enumerated value that identifies the license that the connector is distributed under.
    """

    maturity: ClassVar[Optional[Maturity]] = None
    """An enumerated value that identifies the self-selected maturity level of the connector, provided for
    advisory purposes.
    """

    ##
    # Instance configuration

    config: servo.configuration.BaseConfiguration
    """Configuration for the connector set explicitly or loaded from a config file."""

    # TODO: needs better name... BaseCommonConfiguration? attr can be _base_config or __base_config__
    # NOTE: __shared__ maybe?
    _global_config: servo.configuration.CommonConfiguration = pydantic.PrivateAttr(
        default_factory=lambda: servo.configuration.CommonConfiguration())
    """Shared configuration from our parent Servo instance."""
    @property
    def optimizer(self) -> Optional[servo.configuration.Optimizer]:
        """The optimizer for the connector."""
        return self.config.optimizer

    ##
    # Shared telemetry metadata
    telemetry: servo.telemetry.Telemetry = pydantic.Field(
        default_factory=servo.telemetry.Telemetry)

    ##
    # Validators

    @pydantic.root_validator(pre=True)
    @classmethod
    def _validate_metadata(cls, v):
        assert cls.name is not None, "name must be provided"
        assert cls.version is not None, "version must be provided"
        if isinstance(cls.version, str):
            # Attempt to parse
            cls.version = Version.parse(cls.version)
        assert isinstance(
            cls.version,
            Version), "version is not a semantic versioning descriptor"

        if not cls.__default_name__:
            if name := _name_for_connector_class(cls):
                cls.__default_name__ = name
            else:
                raise ValueError(
                    f"A default connector name could not be constructed for class '{cls}'"
                )
        return v
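
An illustrative sketch, not the servo API, of deriving a default name from the class name in a pre root validator, similar in spirit to _validate_metadata above; the Connector suffix convention and model names are assumptions for the example.

# Deriving a default name from the class name with a pre root_validator (pydantic v1).
import pydantic


class NamedModel(pydantic.BaseModel):
    name: str = None

    @pydantic.root_validator(pre=True)
    def _default_name(cls, values):
        if not values.get("name"):
            # Strip the conventional suffix to produce a readable default
            derived = cls.__name__.removesuffix("Connector").lower()
            values["name"] = derived or cls.__name__.lower()
        return values


class PrometheusConnector(NamedModel):
    pass


print(PrometheusConnector().name)  # "prometheus"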
Example #22
class ESMCatalogModel(pydantic.BaseModel):
    """
    Pydantic model for the ESM data catalog defined in https://git.io/JBWoW
    """

    esmcat_version: pydantic.StrictStr
    attributes: typing.List[Attribute]
    assets: Assets
    aggregation_control: AggregationControl
    id: typing.Optional[str] = ''
    catalog_dict: typing.Optional[typing.List[typing.Dict]] = None
    catalog_file: typing.Optional[pydantic.StrictStr] = None
    description: typing.Optional[pydantic.StrictStr] = None
    title: typing.Optional[pydantic.StrictStr] = None
    _df: typing.Optional[typing.Any] = pydantic.PrivateAttr()

    class Config:
        validate_all = True
        validate_assignment = True

    @pydantic.root_validator
    def validate_catalog(cls, values):
        catalog_dict, catalog_file = values.get('catalog_dict'), values.get('catalog_file')
        if catalog_dict is not None and catalog_file is not None:
            raise ValueError('catalog_dict and catalog_file cannot be set at the same time')

        return values

    @classmethod
    def from_dict(cls, data: typing.Dict) -> 'ESMCatalogModel':
        esmcat = data['esmcat']
        df = data['df']
        cat = cls.parse_obj(esmcat)
        cat._df = df
        return cat

    def save(self, name: str, *, directory: str = None, catalog_type: str = 'dict') -> None:
        """
        Save the catalog to a file.

        Parameters
        -----------
        name: str
            The name of the file to save the catalog to.
        directory: str
            The directory to save the catalog to. If None, use the current directory.
        catalog_type: str
            The type of catalog to save. Whether to save the catalog table as a dictionary
            in the JSON file or as a separate CSV file. Valid options are 'dict' and 'file'.

        Notes
        -----
        Large catalogs can result in large JSON files. To keep the JSON file size manageable, call with
        `catalog_type='file'` to save catalog as a separate CSV file.

        """

        if catalog_type not in {'file', 'dict'}:
            raise ValueError(
                f'catalog_type must be either "dict" or "file". Received catalog_type={catalog_type}'
            )
        csv_file_name = pathlib.Path(f'{name}.csv.gz')
        json_file_name = pathlib.Path(f'{name}.json')
        if directory:
            directory = pathlib.Path(directory)
            directory.mkdir(parents=True, exist_ok=True)
            csv_file_name = directory / csv_file_name
            json_file_name = directory / json_file_name

        data = self.dict().copy()
        for key in {'catalog_dict', 'catalog_file'}:
            data.pop(key, None)
        data['id'] = name

        if catalog_type == 'file':
            data['catalog_file'] = str(csv_file_name)
            self.df.to_csv(csv_file_name, compression='gzip', index=False)
        else:
            data['catalog_dict'] = self.df.to_dict(orient='records')

        with open(json_file_name, 'w') as outfile:
            json.dump(data, outfile, indent=2)

        print(f'Successfully wrote ESM collection json file to: {json_file_name}')

    @classmethod
    def load(
        cls,
        json_file: typing.Union[str, pydantic.FilePath, pydantic.AnyUrl],
        storage_options: typing.Dict[str, typing.Any] = None,
        read_csv_kwargs: typing.Dict[str, typing.Any] = None,
    ) -> 'ESMCatalogModel':
        """
        Loads the catalog from a file

        Parameters
        -----------
        json_file: str or pathlib.Path
            The path to the json file containing the catalog
        storage_options: dict
            fsspec parameters passed to the backend file-system such as Google Cloud Storage,
            Amazon Web Service S3.
        read_csv_kwargs: dict
            Additional keyword arguments passed through to the :py:func:`~pandas.read_csv` function.

        """
        storage_options = storage_options if storage_options is not None else {}
        read_csv_kwargs = read_csv_kwargs or {}
        _mapper = fsspec.get_mapper(json_file, **storage_options)

        with fsspec.open(json_file, **storage_options) as fobj:
            cat = cls.parse_raw(fobj.read())
            if cat.catalog_file:
                if _mapper.fs.exists(cat.catalog_file):
                    csv_path = cat.catalog_file
                else:
                    csv_path = f'{os.path.dirname(_mapper.root)}/{cat.catalog_file}'
                cat.catalog_file = csv_path
                df = pd.read_csv(
                    cat.catalog_file,
                    storage_options=storage_options,
                    **read_csv_kwargs,
                )
            else:
                df = pd.DataFrame(cat.catalog_dict)

            cat._df = df
            cat._cast_agg_columns_with_iterables()
            return cat

    @property
    def columns_with_iterables(self) -> typing.Set[str]:
        """Return a set of columns that have iterables."""
        if self._df.empty:
            return set()
        has_iterables = (
            self._df.sample(20, replace=True)
            .applymap(type)
            .isin([list, tuple, set])
            .any()
            .to_dict()
        )
        return {column for column, check in has_iterables.items() if check}

    @property
    def has_multiple_variable_assets(self) -> bool:
        """Return True if the catalog has multiple variable assets."""
        return self.aggregation_control.variable_column_name in self.columns_with_iterables

    @property
    def df(self) -> pd.DataFrame:
        """Return the dataframe."""
        return self._df

    @df.setter
    def df(self, value: pd.DataFrame) -> None:
        self._df = value

    def _cast_agg_columns_with_iterables(self) -> None:
        """Cast all agg_columns with iterables to tuple values so as
        to avoid hashing issues (e.g. TypeError: unhashable type: 'list')
        """
        columns = list(
            self.columns_with_iterables.intersection(
                set(map(lambda agg: agg.attribute_name, self.aggregation_control.aggregations))
            )
        )
        if columns:
            self._df[columns] = self._df[columns].apply(tuple)

    @property
    def grouped(self) -> typing.Union[pd.core.groupby.DataFrameGroupBy, pd.DataFrame]:
        if self.aggregation_control.groupby_attrs and set(
            self.aggregation_control.groupby_attrs
        ) != set(self.df.columns):
            return self.df.groupby(self.aggregation_control.groupby_attrs)
        return self.df

    def _construct_group_keys(
        self, sep: str = '.'
    ) -> typing.Dict[str, typing.Union[str, typing.Tuple[str]]]:
        grouped = self.grouped
        if isinstance(grouped, pd.core.groupby.generic.DataFrameGroupBy):
            internal_keys = grouped.groups.keys()
            public_keys = map(
                lambda key: key if isinstance(key, str) else sep.join(str(value) for value in key),
                internal_keys,
            )

        else:
            internal_keys = grouped.index
            public_keys = (
                grouped[grouped.columns.tolist()]
                .apply(lambda row: sep.join(str(v) for v in row), axis=1)
                .tolist()
            )

        return dict(zip(public_keys, internal_keys))

    def _unique(self) -> typing.Dict:
        def _find_unique(series):
            values = series.dropna()
            if series.name in self.columns_with_iterables:
                values = tlz.concat(values)
            return list(tlz.unique(values))

        data = self.df[self.df.columns]
        if data.empty:
            return {col: [] for col in self.df.columns}
        else:
            return data.apply(_find_unique, result_type='reduce').to_dict()

    def unique(self) -> pd.Series:
        """Return a series of unique values for each column in the catalog."""
        return pd.Series(self._unique())

    def nunique(self) -> pd.Series:
        """Return a series of the number of unique values for each column in the catalog."""
        return pd.Series(tlz.valmap(len, self._unique()))

    def search(
        self,
        *,
        query: typing.Union['QueryModel', typing.Dict[str, typing.Any]],
        require_all_on: typing.Union[str, typing.List[str]] = None,
    ) -> 'ESMCatalogModel':
        """
        Search for entries in the catalog.

        Parameters
        ----------
        query: dict, optional
            A dictionary of query parameters to execute against the dataframe.
        require_all_on : list, str, optional
            A dataframe column or a list of dataframe columns across
            which all entries must satisfy the query criteria.
            If None, return entries that fulfill any of the criteria specified
            in the query, by default None.

        Returns
        -------
        catalog: ESMCatalogModel
            A new catalog with the entries satisfying the query criteria.

        """

        if not isinstance(query, QueryModel):
            _query = QueryModel(
                query=query, require_all_on=require_all_on, columns=self.df.columns.tolist()
            )
        else:
            _query = query

        results = search(
            df=self.df, query=_query.query, columns_with_iterables=self.columns_with_iterables
        )
        if _query.require_all_on is not None and not results.empty:
            results = search_apply_require_all_on(
                df=results, query=_query.query, require_all_on=_query.require_all_on
            )
        return results
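
A standalone sketch of the columns_with_iterables heuristic above: sample rows with replacement and flag columns whose values are lists, tuples, or sets. The dataframe contents are made up for illustration.

# Detect columns holding iterable values by sampling and inspecting element types.
import pandas as pd

df = pd.DataFrame({
    "source_id": ["CESM2", "GFDL-CM4"],
    "variable": [["tas", "pr"], ["tas"]],  # iterable column
    "path": ["a.nc", "b.nc"],
})

has_iterables = (
    df.sample(20, replace=True)
    .applymap(type)
    .isin([list, tuple, set])
    .any()
    .to_dict()
)
columns_with_iterables = {column for column, check in has_iterables.items() if check}
print(columns_with_iterables)  # {'variable'}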
Example #23
class Assembly(pydantic.BaseModel):
    """
    An Assembly models the environment and runtime configuration of a collection of Servos.

    Connectors are dynamically loaded via setuptools entry points
    (see https://packaging.python.org/specifications/entry-points/)
    and the settings class for a Servo instance must be created at
    runtime because the servo.yaml configuration file includes settings
    for an arbitrary number of connectors mounted onto arbitrary keys
    in the config file.

    The Assembly class is responsible for handling the connector
    loading and creating a concrete BaseServoConfiguration model that supports
    the connectors available and activated in the assembly. An assembly
    is the combination of configuration and associated code artifacts
    in an executable environment (e.g. a Docker image or a Python virtualenv
    running on your workstation).

    NOTE: The Assembly class overrides the Pydantic base class implementations
    of the schema family of methods. See the method docstrings for specific details.
    """

    config_file: Optional[pathlib.Path]
    servos: List[servo.servo.Servo]
    _context_token: Optional[contextvars.Token] = pydantic.PrivateAttr(None)

    @classmethod
    async def assemble(
        cls,
        *,
        config_file: Optional[pathlib.Path] = None,
        configs: Optional[List[Dict[str, Any]]] = None,
        optimizer: Optional[servo.configuration.Optimizer] = None,
        env: Optional[Dict[str, str]] = os.environ,
        **kwargs,
    ) -> "Assembly":
        """Assemble a Servo by processing configuration and building a dynamic settings model"""

        if config_file is None and configs is None:
            raise ValueError(
                "cannot assemble without a config file or config objects")

        _discover_connectors()

        if config_file and not configs:
            # Build our Servo configuration from the config file + environment
            if not config_file.exists():
                raise FileNotFoundError(
                    f"config file '{config_file}' does not exist")

            configs = list(
                yaml.load_all(open(config_file), Loader=yaml.FullLoader))
            if not isinstance(configs, list):
                raise ValueError(
                    f'error: config file "{config_file}" parsed to an unexpected value of type "{configs.__class__}"'
                )

            # If we parsed an empty file, add an empty dict to work with
            if not configs:
                configs.append({})

        if len(configs) > 1 and optimizer is not None:
            raise ValueError(
                "cannot configure a multi-servo assembly with a single optimizer"
            )

        # Set up the event bus and pub/sub exchange
        pubsub_exchange = servo.pubsub.Exchange()
        servos: List[servo.servo.Servo] = []
        for config in configs:
            # TODO: Needs to be public / have a better name
            # TODO: We need to index the env vars here for multi-servo
            servo_config_model, routes = _create_config_model(config=config,
                                                              env=env)
            servo_config = servo_config_model.parse_obj(config)
            if not servo_config.optimizer:
                servo_config.optimizer = optimizer
            servo_optimizer = servo_config.optimizer or optimizer

            telemetry = servo.telemetry.Telemetry()

            # Initialize all active connectors
            connectors: List[servo.connector.BaseConnector] = []
            for name, connector_type in routes.items():
                connector_config = getattr(servo_config, name)
                if connector_config is not None:
                    connector_config.__optimizer__ = servo_optimizer
                    connector = connector_type(
                        name=name,
                        config=connector_config,
                        optimizer=servo_optimizer,
                        pubsub_exchange=pubsub_exchange,
                        telemetry=telemetry,
                        __optimizer__=servo_optimizer,
                        __connectors__=connectors,
                    )
                    connectors.append(connector)

            # Build the servo object
            servo_ = servo.servo.Servo(
                config=servo_config,
                connectors=connectors.copy(
                ),  # Avoid self-referential reference to servo
                optimizer=servo_optimizer,
                telemetry=telemetry,
                __connectors__=connectors,
                pubsub_exchange=pubsub_exchange,
            )
            connectors.append(servo_)
            servos.append(servo_)

        assembly = cls(
            config_file=config_file,
            servos=servos,
        )

        # Attach all connectors to the servo
        await asyncio.gather(*list(
            map(lambda s: s.dispatch_event(servo.servo.Events.attach, s),
                servos)))

        return assembly

    def __init__(self, *args, servos: List[servo.Servo], **kwargs) -> None:
        super().__init__(*args, servos=servos, **kwargs)

        # Ensure object is shared by identity
        self.servos = servos

    ##
    # Utility functions

    async def dispatch_event(
        self,
        event: Union[servo.events.Event, str],
        *args,
        first: bool = False,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        prepositions: servo.events.Preposition = (
            servo.events.Preposition.before
            | servo.events.Preposition.on
            | servo.events.Preposition.after),
        return_exceptions: bool = False,
        **kwargs,
    ) -> Union[Optional[servo.events.EventResult],
               List[servo.events.EventResult]]:
        """Dispatch an event to all servos active in the assembly."""

        group = asyncio.gather(
            *list(
                map(
                    lambda s: s.dispatch_event(
                        event,
                        *args,
                        first=first,
                        include=include,
                        exclude=exclude,
                        prepositions=prepositions,
                        **kwargs,
                    ),
                    self.servos,
                )),
            return_exceptions=return_exceptions,
        )
        results = await group
        if results:
            results = functools.reduce(lambda x, y: x + y, results)

        # TODO: This needs to be tested in multi-servo
        if first:
            return results[0] if results else None

        return results

    @classmethod
    def all_connector_types(cls) -> Set[Type[servo.connector.BaseConnector]]:
        """Return a set of all connector types in the assembly excluding the Servo"""
        return servo.connector._connector_subclasses.copy()

    async def add_servo(self, servo_: servo.servo.Servo) -> None:
        """Add a servo to the assembly.

        Once added, the servo is sent the startup event to prepare for execution.

        Args:
            servo_: The servo to add to the assembly.
        """
        self.servos.append(servo_)

        await servo_.attach(self)

        if self.is_running:
            await servo_.startup()

    async def remove_servo(self, servo_: servo.servo.Servo) -> None:
        """Remove a servo from the assembly.

        Before removal, the servo is sent the detach event to prepare for
        eviction from the assembly.

        Args:
            servo_: The servo to remove from the assembly.
        """

        await servo_.detach(self)

        if self.is_running:
            await servo_.shutdown()

        self.servos.remove(servo_)

    async def startup(self):
        """Notify all servos that the assembly is starting up."""
        await asyncio.gather(*list(map(
            lambda s: s.startup(),
            self.servos,
        )))

    async def shutdown(self):
        """Notify all servos that the assembly is shutting down."""
        await asyncio.gather(*list(
            map(
                lambda s: s.shutdown(),
                filter(
                    lambda s: s.is_running,
                    self.servos,
                ),
            )))
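
A standalone sketch of the fan-out/flatten pattern in Assembly.dispatch_event above: gather per-servo result lists concurrently, then reduce them into a single flat list. The coroutines here are stand-ins, not the servo event API.

# Fan out to several coroutines, then flatten the per-servo result lists.
import asyncio
import functools
from typing import List


async def fake_dispatch(name: str) -> List[str]:
    await asyncio.sleep(0)  # simulate an asynchronous event handler
    return [f"{name}: ok"]


async def main() -> None:
    results = await asyncio.gather(
        *(fake_dispatch(name) for name in ("servo-1", "servo-2")))
    if results:
        results = functools.reduce(lambda x, y: x + y, results)
    print(results)  # ['servo-1: ok', 'servo-2: ok']


asyncio.run(main())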
Example #24
class AddJwksResourceUrlToApp(pydantic.BaseModel):
    organization: typing.Literal["nhsd-nonprod", "nhsd-prod"]
    environment: typing.Literal["internal-dev", "internal-dev-sandbox",
                                "internal-qa", "internal-qa-sandbox", "ref",
                                "dev", "int", "sandbox", "prod", ]
    access_token: str
    app_id: pydantic.UUID4
    jwks_resource_url: pydantic.HttpUrl = pydantic.Field(
        default_factory=default_jwks_resource_url)
    _app_data: typing.Dict = pydantic.PrivateAttr(
        default_factory=_put_app_data)

    @pydantic.validator("environment")
    def check_org_env_combo(cls, environment, values):
        org = values.get("organization")
        if org is None:
            return
        non_prod_envs = [
            "internal-dev",
            "internal-dev-sandbox",
            "internal-qa",
            "internal-qa-sandbox",
            "ref",
        ]
        if org == "nhsd-nonprod" and environment not in non_prod_envs:
            raise ValueError(
                f"Invalid environment {environment} for organization {org}")
        return environment

    @pydantic.validator("environment")
    def cache_put(cls, environment):
        global _environment
        _environment = environment
        return environment

    @pydantic.validator("app_id")
    def check_app_exists(cls, app_id, values):
        access_token = values.get("access_token")
        org = values.get("organization")
        url = f"{APIGEE_BASE_URL}organizations/{org}/apps/{app_id}"
        app_response = utils.get(url, access_token)

        if app_response.get("failed"):
            raise ValueError(
                f"Unable to find app with app_id {app_id} in {org}")

        app_data = app_response["response"]["body"]
        attributes = app_data.get("attributes", [])
        jwks_attribs = [
            a for a in attributes if a["name"] == "jwks-resource-url"
        ]
        if len(jwks_attribs) > 1:
            raise ValueError(
                f"App {app_id} has {len(jwks_attribs)} jwks-resource-url attributes! {[v['value'] for v in jwks_attribs]}"
            )

        # cache response data
        global _cached_app_data
        _cached_app_data = app_data
        global _app_id
        _app_id = app_id
        return app_id

    @pydantic.validator("jwks_resource_url", always=True)
    def check_jwks_url(cls, jwks_resource_url):
        skip_validation = os.environ.get("SKIP_JWKS_RESOURCE_URL_VALIDATION")
        if skip_validation:
            return jwks_resource_url

        resp = requests.get(jwks_resource_url)
        if resp.status_code != 200:
            raise ValueError(
                f"Invalid jwks_resource_url: GET {jwks_resource_url} returned {resp.status_code}"
            )
        try:
            resp.json()
        except Exception:
            raise ValueError(
                f"Invalid jwks_resource_url: GET {jwks_resource_url} returned {resp.content.decode()}, which is not valid JSON"
            )
        return jwks_resource_url
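
A standalone sketch of the organization/environment pairing rule enforced by check_org_env_combo above, using plain pydantic and no Apigee calls; the model and field names are illustrative.

# Validate that non-prod organizations only target non-prod environments.
import typing

import pydantic

NON_PROD_ENVS = {
    "internal-dev", "internal-dev-sandbox", "internal-qa",
    "internal-qa-sandbox", "ref",
}


class AppTarget(pydantic.BaseModel):
    organization: typing.Literal["nhsd-nonprod", "nhsd-prod"]
    environment: str

    @pydantic.validator("environment")
    def check_org_env_combo(cls, environment, values):
        org = values.get("organization")
        if org == "nhsd-nonprod" and environment not in NON_PROD_ENVS:
            raise ValueError(
                f"Invalid environment {environment} for organization {org}")
        return environment


AppTarget(organization="nhsd-nonprod", environment="internal-dev")  # valid
try:
    AppTarget(organization="nhsd-nonprod", environment="prod")
except pydantic.ValidationError as error:
    print(error)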