class LookerConnectionDefinition(ConfigModel):
    # Platform / database / schema that a Looker connection points at.
    # String identifiers are lower-cased by `lower_everything`.
    platform: str
    default_db: str
    # Optional because some platforms only have a two-level namespace.
    default_schema: Optional[str]
    platform_instance: Optional[str] = None
    platform_env: Optional[str] = Field(
        default=None,
        description="The environment that the platform is located in. Leaving this empty will inherit defaults from the top level Looker configuration",
    )

    @validator("platform_env")
    def platform_env_must_be_one_of(cls, v: str) -> str:
        # Delegate to the shared env check so both configs accept the same values.
        return EnvBasedSourceConfigBase.env_must_be_one_of(v)

    @validator("platform", "default_db", "default_schema")
    def lower_everything(cls, v):
        """Lower-case every non-None string so later comparisons ignore case."""
        return None if v is None else v.lower()

    @classmethod
    def from_looker_connection(
        cls, looker_connection: DBConnection
    ) -> "LookerConnectionDefinition":
        """Build a definition from a Looker ``DBConnection``.

        The first regex in the extractor table that matches the connection's
        dialect name selects the function that derives (platform, db, schema).
        Dialect definitions are documented at
        https://docs.looker.com/setup-and-management/database-config

        Raises:
            ConfigurationError: if the connection carries no dialect name, or
                no extractor pattern matches it.
        """
        # Checked in insertion order: specific dialects first, catch-all last.
        extractors: Dict[str, Any] = {
            "^bigquery": _get_bigquery_definition,
            ".*": _get_generic_definition,
        }

        dialect = looker_connection.dialect_name
        if dialect is None:
            raise ConfigurationError(
                f"Unable to fetch a fully filled out connection for {looker_connection.name}. Please check your API permissions."
            )

        for pattern, extract in extractors.items():
            if not re.match(pattern, dialect):
                continue
            platform, db, schema = extract(looker_connection)
            return cls(platform=platform, default_db=db, default_schema=schema)

        raise ConfigurationError(
            f"Could not find an appropriate platform for looker_connection: {looker_connection.name} with dialect: {looker_connection.dialect_name}"
        )
class Link(BaseModel):
    """Represents a link from the topology link file."""

    # Readable attribute names; the aliases are the short keys used in the file.
    id: str = Field(alias="id")
    name: Optional[str] = Field(alias="nm")
    branchid: int = Field(alias="ri")
    modellinktype: int = Field(alias="mt")
    branchtype: int = Field(alias="bt")
    objectid: str = Field(alias="ObID")
    beginnode: str = Field(alias="bn")
    endnode: str = Field(alias="en")

    def _get_identifier(self, data: dict) -> Optional[str]:
        """Return ``data["id"]``, or ``data["nm"]`` when the id is missing/falsy."""
        identifier = data.get("id")
        if identifier:
            return identifier
        return data.get("nm")

    def dict(self, *args, **kwargs):
        """Serialize with the file's aliases; ``by_alias`` is always forced on."""
        return super().dict(*args, **{**kwargs, "by_alias": True})
def create_cloned_field(field: Field) -> Field:
    """Return a copy of ``field`` whose model type (if any) is a new class.

    If the field's type is a dataclass exposing ``__pydantic_model__``, that
    model is used instead. If the resulting type is a ``BaseModel`` subclass,
    a new model with the same name, config and validators is created via
    ``create_model`` and given the original model's field objects; any other
    type is used unchanged. The remaining field attributes are then copied
    from ``field``, and ``sub_fields`` / ``key_field`` are cloned recursively.

    NOTE(review): the new model receives the *same* Field objects as the
    original model (a shallow clone of the model's fields) — confirm this is
    what callers rely on.
    """
    original_type = field.type_
    # Dataclasses that carry a pydantic model are handled via that model.
    if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
        original_type = original_type.__pydantic_model__  # type: ignore
    use_type = original_type
    if lenient_issubclass(original_type, BaseModel):
        original_type = cast(Type[BaseModel], original_type)
        use_type = create_model(  # type: ignore
            original_type.__name__,
            __config__=original_type.__config__,
            __validators__=original_type.__validators__,
        )
        # Copy the original model's field definitions onto the new class.
        for f in original_type.__fields__.values():
            use_type.__fields__[f.name] = f
    # Start from a minimal field; the attributes are overwritten below.
    new_field = Field(
        name=field.name,
        type_=use_type,
        class_validators={},
        default=None,
        required=False,
        model_config=BaseConfig,
        schema=Schema(None),
    )
    new_field.has_alias = field.has_alias
    new_field.alias = field.alias
    new_field.class_validators = field.class_validators
    new_field.default = field.default
    new_field.required = field.required
    new_field.model_config = field.model_config
    new_field.schema = field.schema
    new_field.allow_none = field.allow_none
    new_field.validate_always = field.validate_always
    # Container types (e.g. List[X], Dict[K, V]) keep inner fields here.
    if field.sub_fields:
        new_field.sub_fields = [
            create_cloned_field(sub_field) for sub_field in field.sub_fields
        ]
    if field.key_field:
        new_field.key_field = create_cloned_field(field.key_field)
    new_field.validators = field.validators
    new_field.whole_pre_validators = field.whole_pre_validators
    new_field.whole_post_validators = field.whole_post_validators
    new_field.parse_json = field.parse_json
    new_field.shape = field.shape
    # NOTE(review): _populate_validators() presumably rebuilds the validator
    # lists from class_validators and may override the three validator
    # assignments above — confirm against the pinned pydantic version.
    new_field._populate_validators()
    return new_field
class MyConfig(Config):
    # Application settings schema. How values are loaded (environment, YAML,
    # ...) is decided by the `Config` base class, which is not visible here.
    DEBUG: bool
    HOST: str
    PORT: int
    # Reads the same external key as HOST via alias="HOST".
    # NOTE(review): presumably exercises alias handling — confirm intended.
    ALIAS_HOST: str = Field(alias="HOST")
    REDIS_ADDRESS: str
    REDIS_PASS: Optional[str] = None
    MYSQL_DB_HOST: str
    MYSQL_DB_NAME: str
    MYSQL_DB_PASS: str
    MYSQL_DB_USER: str
    # Json.i() is a project helper; presumably it marks the raw value as
    # JSON-encoded (e.g. '["a", "b"]') — TODO confirm.
    ES_HOST: List[str] = Json.i()
    TEST_LIST_INT: List[int] = Json.i()
    YML_ES_HOST: Optional[List[str]] = None
    YML_TEST_LIST_INT: Optional[List[int]] = None
class CustomSkinLoaderApi(BaseModel):
    # Player profile fetched from `{api_root}/{username}.json`.
    # `pre_processor` derives player_existed / skin_type / cape_existed /
    # skin_hash from the raw JSON before field validation runs.

    class Skins(BaseModel):
        slim: Optional[str]
        default: Optional[str]

    username: Optional[str]
    skins: Optional[Skins]
    skin_hash: Optional[str]
    # Not required: the "cape" key is absent for players without a cape and
    # for unknown players (empty payload). A required default (`...`) made
    # such payloads fail validation.
    cape_hash: Optional[str] = Field(None, alias='cape')
    player_existed: bool = True
    skin_type: Optional[Literal['default', 'slim', None]] = None
    cape_existed: bool = True

    @root_validator(pre=True)
    def pre_processor(cls, values: dict):
        """Compute the derived fields from the raw payload.

        An empty payload means the player does not exist. Otherwise the skin
        type is ``'slim'`` when the ``skins`` object has a ``slim`` key,
        ``'default'`` when it has a truthy ``default`` value, else ``None``;
        ``skin_hash`` is the value stored under that key.
        """
        player_existed = bool(values)
        # 'skins' may be missing or null; treat both as "no skins".
        skins = values.get('skins') or {}

        if not player_existed:
            skin_type = None
            cape_existed = False
        else:
            if 'slim' in skins:
                skin_type = 'slim'
            elif skins.get('default'):
                skin_type = 'default'
            else:
                skin_type = None
            cape_existed = bool(values.get('cape'))

        # Hash of the skin matching the detected type.
        if skin_type == 'default':
            skin_hash = skins['default']
        elif skin_type == 'slim':
            skin_hash = skins['slim']
        else:
            skin_hash = None

        values.update({
            'player_existed': player_existed,
            'skin_type': skin_type,
            'cape_existed': cape_existed,
            'skin_hash': skin_hash,
        })
        return values

    @classmethod
    async def get(cls, api_root: str, username: str):
        """Fetch ``{api_root}/{username}.json`` and parse the response body."""
        async with aiohttp.ClientSession() as session:
            async with session.get(f'{api_root}/{username}.json') as resp:
                return cls.parse_raw(await resp.text())
class ArkitektModel(BaseModel, metaclass=ArkitektModelMeta):
    # Base model adding ward lookup (via Meta.identifier) and readable reprs.
    TYPENAME: str = Field(None, alias='__typename')
    id: Optional[int]

    @classmethod
    def get_ward(cls):
        """Return the ward registered under this model's ``Meta.identifier``.

        Raises:
            ArnheimModelConfigurationError: if the identifier cannot be read
                from the model's Meta.
        """
        try:
            identifier = cls.getMeta().identifier
        except Exception as e:
            raise ArnheimModelConfigurationError(
                f"Make sure your Model {cls.__name__} overwrites Meta identifier: {e}"
            ) from e

        # Imported lazily — presumably to avoid an import cycle with the registry.
        from bergen.registries.ward import get_ward_registry

        return get_ward_registry().get_ward(identifier=identifier)

    @classmethod
    def getMeta(cls):
        return cls.Meta

    def __repr__(self) -> str:
        from pprint import pformat

        return pformat(self.__dict__, indent=4, width=1)

    def _repr_html_(self):
        """Render the instance as an HTML table (IPython rich-display hook)."""
        from html import escape

        def buildTable(attributes: Dict):
            # Escape keys and values: their str() may contain "<" or "&"
            # (e.g. nested object reprs), which would otherwise become markup.
            rows = "".join(
                f"""
                <tr>
                    <td>{escape(str(key).capitalize())}</td>
                    <td>{escape(str(value))}</td>
                </tr>
                """
                for key, value in attributes.items()
            )
            return "<table>" + rows + "</table>"

        return f"""
        <p> Instance of {escape(self.__class__.__name__)} </p>
        {buildTable(self.__dict__)}
        """

    def __setattr__(self, attr, value):
        # Names listed in __slots__ bypass pydantic's attribute handling.
        if attr in self.__slots__:
            object.__setattr__(self, attr, value)
        else:
            super().__setattr__(attr, value)
def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]:
    """Build the single request-body field for an endpoint.

    Returns None when the dependant declares no body parameters. A lone,
    non-embedded body parameter is returned directly; otherwise every body
    parameter becomes a field of a generated ``Body_<name>`` model. The body
    schema class is File if any parameter is a File, else Form if any is a
    Form, else Body (with a shared ``media_type`` when all Body params agree).
    """
    flat_dependant = get_flat_dependant(dependant)
    body_params = flat_dependant.body_params
    if not body_params:
        return None

    first_param = body_params[0]
    embed = getattr(first_param.schema, "embed", None)
    if len(body_params) == 1 and not embed:
        return get_schema_compatible_field(field=first_param)

    BodyModel = create_model("Body_" + name)
    for param in body_params:
        BodyModel.__fields__[param.name] = get_schema_compatible_field(field=param)
    required = any(param.required for param in body_params)

    BodySchema_kwargs: Dict[str, Any] = {"default": None}
    param_schemas = [param.schema for param in body_params]
    if any(isinstance(s, params.File) for s in param_schemas):
        BodySchema: Type[params.Body] = params.File
    elif any(isinstance(s, params.Form) for s in param_schemas):
        BodySchema = params.Form
    else:
        BodySchema = params.Body
        media_types = {s.media_type for s in param_schemas if isinstance(s, params.Body)}
        # Only propagate media_type when every Body param agrees on it.
        if len(media_types) == 1:
            BodySchema_kwargs["media_type"] = next(iter(media_types))

    return Field(
        name="body",
        type_=BodyModel,
        default=None,
        required=required,
        model_config=BaseConfig,
        class_validators={},
        alias="body",
        schema=BodySchema(**BodySchema_kwargs),
    )
class GroupById(BaseEndpoint):
    # GET /groups/{group_id}, optionally requesting extra fields.
    group_id: GroupId
    fields_: Optional[Sequence[Literal[ContactsOptionalFields.METADATA]]] = Field(
        None, alias="fields"
    )

    @property
    def endpoint_data(self) -> EndpointData:
        """Describe the request: method, path and query parameters."""
        url = f"/groups/{self.group_id}"
        return EndpointData(method="GET", url=url, query_params=self._query_params)

    @property
    def _query_params(self) -> WrApiQueryParams:
        """Query string; ``fields`` is sent only for a non-empty selection."""
        query = WrApiQueryParams()
        if not self.fields_:
            return query
        query["fields"] = self._convert_seq(self.fields_)
        return query
class AttachmentsById(BaseEndpoint):
    # GET /attachments/{id1,id2,...} for the given task ids.
    # NOTE(review): under pydantic v1, max_length on a Sequence field may be
    # applied to each string element rather than to the number of ids —
    # confirm whether a 100-item limit was intended.
    task_ids: Sequence[TaskId] = Field(..., max_length=100)
    versions: Optional[StrictBool]

    @property
    def endpoint_data(self) -> EndpointData:
        """Describe the request: method, path and query parameters."""
        joined_ids = ",".join(self.task_ids)
        return EndpointData(
            method="GET",
            url=f"/attachments/{joined_ids}",
            query_params=self._query_params,
        )

    @property
    def _query_params(self) -> WrApiQueryParams:
        """Query string; ``versions`` is sent whenever it was set explicitly."""
        query = WrApiQueryParams()
        if self.versions is None:
            return query
        query["versions"] = self._convert_bool(self.versions)
        return query
class CourseSectionPeriod(BaseModel):
    # One meeting period of a course section.
    semester_id: str = Field(example="202101")
    crn: str = Field(example="42608")
    type: Optional[ClassTypeEnum] = Field(None, example=ClassTypeEnum.LECTURE)
    start_time: Optional[str] = Field(
        None,
        description="24-hour 0-padded start time hh:mm format (RPI time)",
        example="14:00",
    )
    end_time: Optional[str] = Field(
        None,
        description="24-hour 0-padded end time hh:mm format (RPI time)",
        example="15:50",
    )
    instructors: List[str] = Field(
        description="Last names of instructor(s)", example=["Hanna", "Shablovsky"]
    )
    location: Optional[str] = Field(
        description="Location of class (null if not yet determined or online)",
        example="SAGE 114",
    )
    days: List[int] = Field(
        description="Days of week period meets (0-Sunday)", example=[1, 4]
    )

    @classmethod
    def from_record(cls, record: Dict[str, Any]) -> "CourseSectionPeriod":
        """Create a period from a DB record (mapping of field name to value).

        A classmethod so subclasses construct instances of themselves.
        """
        return cls(**record)

    def to_record(self) -> Dict[str, Any]:
        """Convert the period to a flat dictionary to store in the DB.

        ``self.dict()`` already returns a fresh dict, so no extra copy is made.
        """
        return self.dict()

    def __str__(self) -> str:
        return f"{self.type} on days {self.days} from {self.start_time}-{self.end_time} with {self.instructors} at {self.location}"

    class Config:
        # Store enum fields (e.g. `type`) as their values, not enum members.
        use_enum_values = True
class VADAnnotation(BaseModel):
    # Video annotation whose `frames` list must contain exactly
    # `frames_count` entries; unknown keys are kept (Config.extra = "allow").
    frames_count: conint(gt=0)
    is_anomalous_regions_available: bool
    is_anomaly_track_id_available: bool
    video_length_sec: Optional[PositiveFloat] = None
    frame_width: conint(gt=0)
    frame_height: conint(gt=0)
    frame_rate: Optional[PositiveFloat] = None
    frames: List[VADFrame] = Field(..., description="len(frames) == frames_count")

    @validator('frames')
    def frames_len(cls, v, values, **kwargs):
        """Check that the number of frames equals ``frames_count``.

        Raises:
            ValueError: on a mismatch. Pydantic wraps it into its own
                ValidationError; pydantic's ValidationError itself cannot be
                constructed from a plain message inside a validator.
        """
        frames_count = values.get("frames_count")
        # frames_count is missing from `values` when it failed its own
        # validation; that error is already reported, so skip the check.
        if frames_count is not None and len(v) != frames_count:
            raise ValueError(
                "Length of 'frames' does not match 'frames_count'")
        return v

    class Config:
        extra = "allow"
class SpaceById(BaseEndpoint):
    # GET /spaces/{space_id}, optionally requesting extra fields.
    space_id: SpaceId
    fields_: Optional[Sequence[SpaceOptionalFields]] = Field(None, alias="fields")

    @property
    def endpoint_data(self) -> EndpointData:
        """Describe the request: method, path and query parameters."""
        url = f"/spaces/{self.space_id}"
        return EndpointData(method="GET", url=url, query_params=self._query_params)

    @property
    def _query_params(self) -> WrApiQueryParams:
        """Query string; ``fields`` is sent only for a non-empty selection."""
        query = WrApiQueryParams()
        if not self.fields_:
            return query
        query["fields"] = self._convert_seq(self.fields_)
        return query
class LinkFile(FileModel):
    """Represents the file with the RR link topology data."""

    # Parser configured for the "brch" enclosing tag of this file format.
    _parser = NetworkTopologyFileParser(enclosing_tag="brch")

    link: List[Link] = Field([], alias="brch")

    @classmethod
    def _get_parser(cls) -> Callable:
        """Callable used to parse a link file."""
        return cls._parser.parse

    @classmethod
    def _get_serializer(cls) -> Callable:
        """Callable used to serialize a link file."""
        return LinkFileSerializer.serialize

    @classmethod
    def _filename(cls) -> str:
        """Base file name, without extension."""
        return "3b_link"

    @classmethod
    def _ext(cls) -> str:
        """File extension, including the leading dot."""
        return ".tp"
class NodeFile(FileModel):
    """Represents the file with the RR node topology data."""

    # Parser configured for the "node" enclosing tag of this file format.
    _parser = NetworkTopologyFileParser(enclosing_tag="node")

    node: List[Node] = Field([], alias="node")

    @classmethod
    def _get_parser(cls) -> Callable:
        """Callable used to parse a node file."""
        return cls._parser.parse

    @classmethod
    def _get_serializer(cls) -> Callable:
        """Callable used to serialize a node file."""
        return NodeFileSerializer.serialize

    @classmethod
    def _filename(cls) -> str:
        """Base file name, without extension."""
        return "3b_nod"

    @classmethod
    def _ext(cls) -> str:
        """File extension, including the leading dot."""
        return ".tp"
class RollingUpdate(BaseModel):
    """
    This model is used for both DaemonSets and Deployments. Although the models
    are identical, the underlying strategies differ between the two.

    For a Deployment, max_unavailable refers to how far the old ReplicaSet can
    be scaled down; i.e. max_unavailable is the maximum number of Pods that may
    be unavailable during the update.

    For a DaemonSet, max_unavailable refers to the number of Nodes that should
    be running the daemon Pod (despite what is mentioned in the docs). If the
    number of Nodes with unavailable daemon Pods reaches max_unavailable,
    Kubernetes will not stop Pods on other Nodes in order to update them.

    The same distinction applies to max_surge. The documentation claims that
    only one DaemonSet Pod is created, but as of v1.21, max_surge allows a
    second Pod to be scheduled for the duration of the update.
    """

    # Constant tag (const=True); presumably distinguishes this strategy from others.
    type_: Literal["RollingUpdate"] = Field("RollingUpdate", const=True)
    max_surge: str  # This field was introduced in Kubernetes v1.21.
    # NOTE(review): both values are stored as str — presumably an absolute
    # count or a percentage string; confirm against the producer.
    max_unavailable: str
class OfferVenueResponse(BaseModel):
    # Venue of an offer, built from the ORM venue via `from_orm`.
    id: int
    address: Optional[str]
    city: Optional[str]
    managingOfferer: OfferOffererResponse = Field(..., alias="offerer")
    name: str
    postalCode: Optional[str]
    publicName: Optional[str]
    coordinates: Coordinates

    class Config:
        orm_mode = True
        allow_population_by_field_name = True

    @classmethod
    def from_orm(cls, venue):  # type: ignore
        """Build the response from an ORM venue.

        Side effect: sets ``venue.coordinates`` on the passed object (from its
        latitude/longitude) so pydantic can read it as a single attribute.
        """
        latitude = venue.latitude
        longitude = venue.longitude
        venue.coordinates = {"latitude": latitude, "longitude": longitude}
        return super().from_orm(venue)
class UserProfileResponse(BaseModel):
    # Profile of a user, built from the ORM user via `from_orm`.
    id: int
    dateOfBirth: Optional[datetime.datetime]
    deposit_expiration_date: Optional[datetime.datetime]
    deposit_version: Optional[int]
    email: str
    expenses: List[Expense]
    firstName: Optional[str]
    hasAllowedRecommendations: bool
    is_eligible: bool
    lastName: Optional[str]
    isBeneficiary: bool
    phoneNumber: Optional[str]
    publicName: Optional[str] = Field(None, alias="pseudo")
    needsToFillCulturalSurvey: bool
    show_eligible_card: bool

    class Config:
        orm_mode = True
        alias_generator = to_camel
        allow_population_by_field_name = True

    @validator("publicName", pre=True)
    def format_public_name(cls, publicName: str) -> Optional[str]:  # pylint: disable=no-self-argument
        # VOID_PUBLIC_NAME is presumably a placeholder; expose it as None.
        if publicName == VOID_PUBLIC_NAME:
            return None
        return publicName

    @validator("firstName", pre=True)
    def format_first_name(cls, firstName: Optional[str]) -> Optional[str]:  # pylint: disable=no-self-argument
        # VOID_FIRST_NAME is presumably a placeholder; expose it as None.
        if firstName == VOID_FIRST_NAME:
            return None
        return firstName

    @staticmethod
    def _show_eligible_card(user: User) -> bool:
        """Whether to show the eligibility card.

        True only if the user's age at account creation is below
        ``ELIGIBILITY_AGE``, they are not a beneficiary, and they are eligible.
        """
        age_at_creation = relativedelta(user.dateCreated, user.dateOfBirth).years
        if age_at_creation >= users_constants.ELIGIBILITY_AGE:
            return False
        return user.isBeneficiary is False and user.is_eligible

    @classmethod
    def from_orm(cls, user):
        # The computed flag is set on the ORM object so pydantic reads it like
        # any other attribute (side effect on `user`).
        user.show_eligible_card = cls._show_eligible_card(user)
        return super().from_orm(user)
class EnvBasedSourceConfigBase(ConfigModel):
    """
    Any source that produces dataset urns in a single environment should inherit this class
    """

    env: str = Field(
        default=FabricTypeClass.PROD,
        description="The environment that all assets produced by this connector belong to",
    )

    @validator("env")
    def env_must_be_one_of(cls, v: str) -> str:
        """Upper-case ``v`` and require it to be one of FabricTypeClass's values.

        Raises:
            ConfigurationError: if the upper-cased value is not allowed.
        """
        # FabricTypeClass is not an Enum, so its public class attributes are
        # taken as the set of allowed values.
        allowed_envs = [
            value
            for attr_name, value in vars(FabricTypeClass).items()
            if not attr_name.startswith("_")
        ]
        normalized = v.upper()
        if normalized in allowed_envs:
            return normalized
        raise ConfigurationError(
            f"env must be one of {allowed_envs}, found {v}")
class SupersetConfig(ConfigModel):
    # Connection settings for the Superset REST API. For the auth fields see
    # the /security/login endpoint: https://superset.apache.org/docs/rest-api
    connect_uri: str = Field("localhost:8088", description="Superset host URL.")
    username: Optional[str] = Field(None, description="Superset username.")
    password: Optional[str] = Field(None, description="Superset password.")
    provider: str = Field("db", description="Superset provider.")
    options: Dict = Field({}, description="")
    env: str = Field(
        DEFAULT_ENV,
        description="Environment to use in namespace when constructing URNs",
    )
    database_alias: Dict[str, str] = Field(
        {},
        description="Can be used to change mapping for database names in superset to what you have in datahub",
    )

    @validator("connect_uri")
    def remove_trailing_slash(cls, v):
        """Normalize ``connect_uri`` with ``config_clean.remove_trailing_slashes``."""
        cleaned = config_clean.remove_trailing_slashes(v)
        return cleaned
class OfferVenueResponse(BaseModel):
    # Venue of an offer, built from the ORM venue via `from_orm`.
    id: int
    address: Optional[str]
    city: Optional[str]
    managingOfferer: OfferOffererResponse = Field(..., alias="offerer")
    name: str
    postalCode: Optional[str]
    publicName: Optional[str]
    coordinates: Coordinates
    isPermanent: bool

    class Config:
        orm_mode = True
        allow_population_by_field_name = True

    @classmethod
    def from_orm(cls, venue):  # type: ignore
        """Build the response from an ORM venue.

        Side effect: sets ``venue.coordinates`` on the passed object (from its
        latitude/longitude) so pydantic can read it as a single attribute.
        """
        venue.coordinates = {"latitude": venue.latitude, "longitude": venue.longitude}
        response = super().from_orm(venue)
        # FIXME: remove this line once Venue.isPermanent is not nullable
        # NOTE(review): with `isPermanent: bool`, pydantic presumably rejects
        # None during from_orm validation, before this fallback runs — confirm;
        # a pre=True validator on isPermanent would be the usual fix.
        if not response.isPermanent:
            response.isPermanent = False
        return response
class DruidConfig(BasicSQLAlchemyConfig):
    # defaults
    scheme = "druid"
    schema_pattern: AllowDenyPattern = Field(
        default=AllowDenyPattern(deny=["^(lookup|sys).*"]),
        description="regex patterns for schemas to filter in ingestion.",
    )

    def get_sql_alchemy_url(self):
        """Return the base SQLAlchemy URL with Druid's SQL path appended."""
        return f"{super().get_sql_alchemy_url()}/druid/v2/sql/"

    def get_identifier(self, schema: str, table: str) -> str:
        """Return the dataset identifier for ``table``; ``schema`` is ignored.

        The pydruid library already formats the table name correctly, so the
        schema name is not needed when constructing the URN. Without this
        override, every URN would incorrectly start with "druid.".
        For more information, see
        https://druid.apache.org/docs/latest/querying/sql.html#schemata-table

        When ``platform_instance`` is set the result is
        ``"<platform_instance>.<table>"``, otherwise just the table name.
        """
        if self.platform_instance:
            return f"{self.platform_instance}.{table}"
        return f"{table}"
class LDAPSourceConfig(ConfigModel):
    """Config used by the LDAP Source."""

    # --- Server connection (all required) ---
    ldap_server: str = Field(description="LDAP server URL.")
    ldap_user: str = Field(description="LDAP user.")
    ldap_password: str = Field(description="LDAP password.")

    # --- Extraction settings ---
    base_dn: str = Field(description="LDAP DN.")
    filter: str = Field("(objectClass=*)", description="LDAP extractor filter.")
    drop_missing_first_last_name: bool = Field(
        True,
        description="If set to true, any users without first and last names will be dropped.",
    )
    page_size: int = Field(
        20,
        description="Size of each page to fetch when extracting metadata.",
    )
class _WorkSchedulesBaseEndpoint(BaseEndpoint):
    # Shared GET logic for work-schedule endpoints; subclasses supply `_url`.
    fields_: Optional[Sequence[WorkScheduleFields]] = Field(None, alias="fields")

    @property
    def _url(self) -> str:
        """Request path; must be overridden by concrete endpoints."""
        raise NotImplementedError()

    @property
    def _query_params(self) -> WrApiQueryParams:
        """Query string; ``fields`` is sent only for a non-empty selection."""
        query = WrApiQueryParams()
        if not self.fields_:
            return query
        query["fields"] = self._convert_seq(self.fields_)
        return query

    @property
    def endpoint_data(self) -> EndpointData:
        """Describe the request: method, path and query parameters."""
        url = self._url
        return EndpointData(method="GET", url=url, query_params=self._query_params)
class ModifyTasksById(BaseEndpoint):
    # PUT /tasks/{id1,id2,...} updating custom fields and/or effort allocation.
    task_ids: Sequence[TaskId] = Field(..., max_length=100)
    custom_fields: Optional[Sequence[CustomField]]
    effort_allocation: Optional[TaskEffort]

    @property
    def endpoint_data(self) -> EndpointData:
        """Describe the request: method, path and JSON body."""
        joined_ids = ",".join(self.task_ids)
        return EndpointData(
            method="PUT",
            url=f"/tasks/{joined_ids}",
            body_params=self._body_params,
        )

    @property
    def _body_params(self) -> BodyParams:
        """Request body containing only the attributes that were provided."""
        body: BodyParams = {}
        if self.custom_fields is not None:
            body["customFields"] = self._convert_input_seq(self.custom_fields)
        # NOTE(review): truthiness check, unlike the `is not None` above; the
        # two only differ if a TaskEffort value can be falsy — confirm.
        if self.effort_allocation:
            body["effortAllocation"] = self._convert_input(self.effort_allocation)
        return body