def test_validator_stacking(self):
    """Validators declared on a NewType must stack with validators added
    via field metadata, preserving order (NewType's validators first)."""
    # See: https://github.com/lovasoa/marshmallow_dataclass/issues/91
    class SimpleValidator(Validator):
        # Marshmallow checks for valid validators at construction time only
        # using `callable`, so an empty __call__ suffices for these tests.
        def __call__(self):
            pass

    validator_a = SimpleValidator()
    validator_b = SimpleValidator()
    validator_c = SimpleValidator()
    validator_d = SimpleValidator()

    CustomTypeOneValidator = NewType(
        "CustomTypeOneValidator", str, validate=validator_a
    )
    CustomTypeNoneValidator = NewType(
        "CustomTypeNoneValidator", str, validate=None
    )
    # FIX: this NewType was mistakenly registered under the copy-pasted name
    # "CustomTypeNoneValidator"; register it under its own name.
    CustomTypeMultiValidator = NewType(
        "CustomTypeMultiValidator", str, validate=[validator_a, validator_b]
    )

    # No validator on the type, none in metadata -> no validators.
    @dataclasses.dataclass
    class A:
        data: CustomTypeNoneValidator = dataclasses.field()

    schema_a = class_schema(A)()
    self.assertListEqual(schema_a.fields["data"].validators, [])

    # No validator on the type, a single one in metadata.
    @dataclasses.dataclass
    class B:
        data: CustomTypeNoneValidator = dataclasses.field(
            metadata={"validate": validator_a}
        )

    schema_b = class_schema(B)()
    self.assertListEqual(schema_b.fields["data"].validators, [validator_a])

    # No validator on the type, a list in metadata.
    @dataclasses.dataclass
    class C:
        data: CustomTypeNoneValidator = dataclasses.field(
            metadata={"validate": [validator_a, validator_b]}
        )

    schema_c = class_schema(C)()
    self.assertListEqual(
        schema_c.fields["data"].validators, [validator_a, validator_b]
    )

    # One validator on the type, none in metadata.
    @dataclasses.dataclass
    class D:
        data: CustomTypeOneValidator = dataclasses.field()

    schema_d = class_schema(D)()
    self.assertListEqual(schema_d.fields["data"].validators, [validator_a])

    # One validator on the type plus one in metadata -> both, type's first.
    @dataclasses.dataclass
    class E:
        data: CustomTypeOneValidator = dataclasses.field(
            metadata={"validate": validator_b}
        )

    schema_e = class_schema(E)()
    self.assertListEqual(
        schema_e.fields["data"].validators, [validator_a, validator_b]
    )

    # One validator on the type plus a list in metadata.
    @dataclasses.dataclass
    class F:
        data: CustomTypeOneValidator = dataclasses.field(
            metadata={"validate": [validator_b, validator_c]}
        )

    schema_f = class_schema(F)()
    self.assertListEqual(
        schema_f.fields["data"].validators,
        [validator_a, validator_b, validator_c],
    )

    # Two validators on the type, none in metadata.
    @dataclasses.dataclass
    class G:
        data: CustomTypeMultiValidator = dataclasses.field()

    schema_g = class_schema(G)()
    self.assertListEqual(
        schema_g.fields["data"].validators, [validator_a, validator_b]
    )

    # Two validators on the type plus one in metadata.
    @dataclasses.dataclass
    class H:
        data: CustomTypeMultiValidator = dataclasses.field(
            metadata={"validate": validator_c}
        )

    schema_h = class_schema(H)()
    self.assertListEqual(
        schema_h.fields["data"].validators,
        [validator_a, validator_b, validator_c],
    )

    # Two validators on the type plus a list in metadata.
    @dataclasses.dataclass
    class J:
        data: CustomTypeMultiValidator = dataclasses.field(
            metadata={"validate": [validator_c, validator_d]}
        )

    schema_j = class_schema(J)()
    self.assertListEqual(
        schema_j.fields["data"].validators,
        [validator_a, validator_b, validator_c, validator_d],
    )
class SummaryGrouping:
    """String constants naming the dimensions a summary may be grouped by."""
    BEGIN = 'begin'
    SECTOR = 'sector'
    TECHNOLOGY = 'technology'
    TECHNOLOGY_CODE = 'technologyCode'
    FUEL_CODE = 'fuelCode'


# -- Forecast ----------------------------------------------------------------

# (De)serializes an ISO 8601 duration string.
# NOTE(review): `typ` is declared as int, yet `deserialize` yields an isodate
# Duration and `serialize` reads `f.resolution` — confirm the intended
# Python-side type against callers.
ForecastDuration = NewType(
    name='ForecastDuration',
    typ=int,
    field=fields.Function,
    deserialize=lambda s: isodate.parse_duration(s),
    serialize=lambda f: isodate.duration_isoformat(f.resolution),
)

# Serialize-only: list of forecast begin times as ISO 8601 strings.
ForecastBeginList = NewType(
    name='ForecastBeginList',
    typ=List[datetime],
    field=fields.Function,
    serialize=lambda f: [begin.isoformat() for begin in f.get_begins()],
)

# NOTE: definition truncated in this view — it continues beyond this chunk.
ForecastEndList = NewType(
    name='ForecastEndList',
    typ=List[datetime],
    field=fields.Function,
sub: str # -- Login request and response -------------------------------------------- @dataclass class LoginRequest: return_url: str = field(metadata=dict(data_key='returnUrl')) # -- VerifyLoginCallback request and response -------------------------------- Scope = NewType( name='Scope', typ=List[str], field=fields.Function, deserialize=lambda scope: scope.split(' '), ) @dataclass class VerifyLoginCallbackRequest: scope: Scope code: str state: str # -- GetAccounts request and response ---------------------------------------- @dataclass
When we serialize a UNIX timestamp, return it as a float. Without doing
        this, we'd return UNIX timestamps as strings, which isn't quite what
        we want.
        """
        result = super()._serialize(value, attr, obj, **kwargs)
        data_format = self.format or self.DEFAULT_FORMAT
        # marshmallow's DateTime renders "timestamp" as a string; coerce it.
        if data_format == 'timestamp':
            return float(result)
        return result


# A field that serializes datetime objects as UNIX timestamps.
TimestampField = NewType(
    "TimestampField", datetime.datetime, field=DateTime, format="timestamp"
)


class MetricUnit(Enum):
    """Supported measurement metrics."""

    WH_GENERATED = "whG"
    WH_USED = "whU"
    TEMP_CELSIUS = "tempC"


class GeoUnit(Enum):
    """Geographic units available for geo queries."""

    M = "m"
    KM = "km"
    MI = "mi"
value: typing.Any,
                    attr: str = None,
                    data: typing.Mapping[str, typing.Any] = None,
                    **kwargs):
        # Pass through values that are already datetimes; otherwise defer to
        # the parent field's deserialization.
        if isinstance(value, dt):
            return value
        return super().deserialize(value, attr, data, **kwargs)


class DateTimeValidator(Validator):
    """Validator accepting only exact `datetime` instances."""

    default_message = "Not a valid datetime"

    def _format_error(self, value) -> str:
        # NOTE(review): appends " error " to the offending value before
        # delegating to the base formatter — confirm this wording is intended.
        value_text = f"{value} error "
        return super()._format_error(value_text)

    def __call__(self, value) -> dt:
        try:
            # Exact type check: rejects datetime *subclasses* too.
            if type(value) is not dt:
                raise ValidationError(self._format_error(value))
        except TypeError as error:
            # NOTE(review): `type()` does not raise TypeError, so this branch
            # looks unreachable — confirm before removing.
            raise ValidationError(self._format_error(value)) from error
        return value


# Dataclass-friendly datetime type: custom field plus strict validator.
# NOTE(review): `validate` is given the *class*, not an instance — marshmallow
# will call DateTimeValidator(value), constructing a (truthy) instance rather
# than validating; `DateTimeValidator()` was probably intended. Verify.
DataclassDateTime = NewType(
    "NewDateTime", dt, field=PyDateTimeField, validate=DateTimeValidator
)
# pylint: disable=invalid-name from __future__ import absolute_import import uuid from dataclasses import field from typing import Optional, List, ClassVar, Type, Dict, Any from datetime import datetime from collections.abc import Hashable from marshmallow import Schema, fields, post_dump from marshmallow_dataclass import dataclass, NewType, add_schema from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY # Define the UUID type that uses Marshmallow's UUID + Python's UUID UUID = NewType("UUID", uuid.UUID, field=fields.UUID) # Marshmallow base schema for skipping None values on dump class BaseSchema(Schema): SKIP_VALUES = {None} @post_dump # pylint: disable=unused-argument def remove_skip_values(self, data: Any, many: bool) -> Dict[Any, Any]: return { key: value for key, value in data.items() if not isinstance(value, Hashable) or value not in self.SKIP_VALUES }
from eduid_scimapi.db.eventdb import EventLevel
from eduid_scimapi.schemas.scimbase import (
    BaseCreateRequest,
    BaseResponse,
    BaseSchema,
    DateTimeField,
    SCIMResourceType,
    SCIMSchema,
)

__author__ = 'ft'

# SCIMResourceType enum (de)serialized by its string value.
SCIMResourceTypeValue = NewType(
    'SCIMResourceTypeValue',
    SCIMResourceType,
    field=EnumField,
    enum=SCIMResourceType,
    by_value=True,
)


@dataclass
class NutidEventResource:
    """Reference to the SCIM resource an event concerns."""

    resource_type: SCIMResourceTypeValue = field(metadata={
        'data_key': 'resourceType',
        'required': True
    })
    scim_id: UUID = field(metadata={'data_key': 'id', 'required': True})
    # NOTE: truncated in this view — the metadata dict continues beyond
    # this chunk.
    external_id: Optional[str] = field(default=None, metadata={
        'data_key': 'externalId',
        'required': False
class City:
    """A city with an optional name, a contact address and geography records."""

    # Optional display name.
    name: Optional[str]
    # Validated e-mail address (inline NewType wrapping marshmallow's Email field).
    contact: NewType("Email", str, field=marshmallow.fields.Email)
    # One entry per geographic sub-area; defaults to empty.
    geography: List[Geography] = field(default_factory=list)
# Semantic-version core: X.Y with optional .Z plus pre-release word chars.
_version = r'\d+\.\d+(\.\d+[\w-]*)*'
CONDITION_VERSION_PATTERN = rf'^\^{_version}$'  # caret ranges, e.g. "^1.2.3"
VERSION_PATTERN = f'^{_version}$'
BRANCH_PATTERN = f'{VERSION_PATTERN}|^master$'
INTERVAL_PATTERN = r'\d+[mshd]'  # e.g. "5m", "1h"

# MITRE ATT&CK reference URL shapes.
TACTIC_URL = r'https://attack.mitre.org/tactics/TA[0-9]+/'
TECHNIQUE_URL = r'https://attack.mitre.org/techniques/T[0-9]+/'
SUBTECHNIQUE_URL = r'https://attack.mitre.org/techniques/T[0-9]+/[0-9]+/'

# Rule type discriminators.
MACHINE_LEARNING = 'machine_learning'
SAVED_QUERY = 'saved_query'
QUERY = 'query'

OPERATORS = ['equals']

# Validated string/int aliases used across the rule schemas.
CodeString = NewType("CodeString", str)
ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN))
# NOTE(review): DATE_PATTERN is defined elsewhere in this module.
Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN))
FilterLanguages = Literal["kuery", "lucene"]
Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN))
PositiveInteger = NewType('PositiveInteger', int, validate=validate.Range(min=1))
Markdown = NewType("MarkdownField", CodeString)
Maturity = Literal['development', 'experimental', 'beta', 'production', 'deprecated']
MaxSignals = NewType("MaxSignals", int, validate=validate.Range(min=1))
NonEmptyStr = NewType('NonEmptyStr', str, validate=validate.Length(min=1))
Operator = Literal['equals']
""" __tablename__ = 'ggo_technology' __table_args__ = (sa.UniqueConstraint('technology_code', 'fuel_code'), ) id = sa.Column(sa.Integer(), primary_key=True, index=True) technology = sa.Column(sa.String(), nullable=False) technology_code = sa.Column(sa.String(), index=True, nullable=False) fuel_code = sa.Column(sa.String(), index=True, nullable=False) # -- Common ------------------------------------------------------------------ GgoTechnology = NewType( name='GgoTechnology', typ=str, field=marshmallow.fields.Function, serialize=lambda ggo: ggo.technology.technology if ggo.technology else None, ) @dataclass class MappedGgo: """ A reflection of the Ggo class above, but supports JSON schema serialization/deserialization using marshmallow/marshmallow-dataclass. """ address: str sector: str begin: datetime end: datetime
super().__init__(*args, **kwargs)

    def _serialize(self, value: pathlib.Path, *args: T.Any,
                   **kwargs: T.Any) -> T.Optional[str]:
        # Dump a Path as its string form; None passes through unchanged.
        if value is None:
            return None
        return str(value)

    def _deserialize(self, value: str, *args: T.Any,
                     **kwargs: T.Any) -> T.Optional[pathlib.Path]:
        # Load a string back into a pathlib.Path; None passes through.
        if value is None:
            return None
        return pathlib.Path(value)


# Dataclass-friendly filesystem path type backed by PathField above.
Path = NewType("Path", object, field=PathField)


@add_schema
@dataclass
class Metadata:
    """For each Metadata JSON file, what can we expect to find"""
    json_path: Path
    name: str
    # Serialized under the JSON key "alternatives".
    alt_params: T.List[AltParams] = dataclasses.field(
        repr=False, default_factory=list, metadata={"data_key": "alternatives"})
    url: str = dataclasses.field(repr=False, default="")
    tags: T.List[str] = dataclasses.field(repr=False, default_factory=list)
def validate_iso8601_duration(d): if d.total_seconds() <= 0: raise ValidationError('Duration too short') def deserialize_iso8601_duration(s): try: return isodate.parse_duration(s) except isodate.isoerror.ISO8601Error: raise ValidationError('Invalid ISO8601 duration') ForecastDuration = NewType( name='ForecastDuration', typ=int, field=fields.Function, deserialize=deserialize_iso8601_duration, serialize=lambda f: isodate.duration_isoformat(timedelta(seconds=f.resolution)), validate=validate_iso8601_duration, ) ForecastSender = NewType( name='ForecastSender', typ=int, field=fields.Function, serialize=lambda f: f.user.sub, ) ForecastRecipient = NewType( name='ForecastRecipient',
    if not facility.public_id:
        facility.public_id = str(uuid4())

    if not facility.name:
        # Fall back to a human-readable name built from the address parts.
        facility.name = '%s, %s %s' % (
            facility.address,
            facility.postcode,
            facility.city_name,
        )


# -- Common ------------------------------------------------------------------

# Serialize-only technology label with a fallback for unknown technologies.
FacilityTechnology = NewType(
    name='FacilityTechnology',
    typ=str,
    field=marshmallow.fields.Function,
    serialize=lambda facility: (facility.technology.technology
                                if facility.technology
                                else UNKNOWN_TECHNOLOGY_LABEL),
)

# Serialize-only list of tag strings attached to a facility.
# NOTE(review): name is 'FacilityTag' and typ is str although the serialized
# value is a list — confirm whether 'FacilityTagList' / List[str] was intended.
FacilityTagList = NewType(
    name='FacilityTag',
    typ=str,
    field=marshmallow.fields.Function,
    serialize=lambda facility: [t.tag for t in facility.tags],
)


@dataclass
class MappedFacility:
    """
return "" value = value.strip().lower() if not value.startswith("--"): prefix = "" for idx in range(2): if value[idx] != "-": prefix += "-" value = f"{prefix}{value}" assert value in { "", "--edge", "--classic", }, f"Unknown SNAP policy: {value}" return value SnapPolicy = NewType("SnapPolicy", str, field=SnapPolicyField) @add_schema @dataclass class SnapTemplateParams: """JSON Parameters required to fill in the Snap templates""" pkg: str = field(default="") policy: SnapPolicy = field(default="") Schema: T.ClassVar[T.Type[ma.Schema]] = ma.Schema class Meta: ordered = True
import marshmallow
from typing import List
from itertools import groupby
from datetime import datetime, timezone
from dataclasses import dataclass

from marshmallow_dataclass import NewType

from origin.common import EmissionValues

# Dict of emission figures deserialized into an EmissionValues instance.
EmissionValuesType = NewType(
    name='EmissionValuesType',
    typ=dict,
    field=marshmallow.fields.Function,
    deserialize=lambda emissions: EmissionValues(**emissions),
)


@dataclass
class EmissionPart:
    """Emissions attributed to a single technology."""

    technology: str

    # Consumed amount in Wh
    amount: int

    # Emissions in gram
    emissions: EmissionValuesType


@dataclass
class EmissionData:
    # NOTE: definition truncated in this view — fields continue beyond
    # this chunk.
    sector: str
@dataclass
class GetOnboadingUrlRequest:
    """Request for an onboarding URL; `returnUrl` is the post-flow redirect."""
    return_url: str = field(metadata=dict(data_key='returnUrl'))


@dataclass
class GetOnboadingUrlResponse:
    """Response carrying the onboarding URL (valid only when success=True)."""
    success: bool
    url: str


# -- GetDisclosure request and response --------------------------------------

# A single measurement value; None means "no data for this interval".
MeasurementValue = NewType('MeasurementValue', int, allow_none=True)


# FIX: removed the erroneous @dataclass decorator — DisclosureState is an
# Enum, and applying @dataclass to an Enum is at best a no-op and breaks
# member definition on newer Python versions.
class DisclosureState(Enum):
    PENDING = 'PENDING'
    PROCESSING = 'PROCESSING'
    AVAILABLE = 'AVAILABLE'


@dataclass
class DisclosureDataSeries:
    """One disclosed data series, identified by its GSRN."""
    # NOTE(review): str fields default to None — confirm Optional[str] was
    # not intended.
    gsrn: str = field(default=None)
    address: str = field(default=None)
    measurements: List[MeasurementValue] = field(default_factory=list)
    ggos: List[SummaryGroup] = field(default_factory=list)
    # Error-message templates.
    invalid_content_type = "Doesn't have a valid content type, matching {regex}"
    invalid_filename = "Doesn't have a valid filename, matching {regex}"
    file_is_empty = "Size of the file is {size}"

    # NOTE(review): r"image/*" matches "image" followed by any number of
    # slashes — r"image/.+" may have been intended. Verify.
    content_type_regex = re.compile(r"image/*")
    filename_regex = re.compile(r"^.*\.(jpeg|jpg|gif|png)$")

    def __call__(self, value) -> typing.Any:
        """Validate an uploaded file-like object and return it unchanged.

        Raises ValidationError when the content type, filename, or content
        fails the checks above.
        """
        if not hasattr(value, 'content_type') or self.content_type_regex.match(
                value.content_type) is None:
            raise ValidationError(
                self.invalid_content_type.format(
                    regex=self.content_type_regex.pattern))
        if not hasattr(value, 'filename') or self.filename_regex.match(
                value.filename) is None:
            raise ValidationError(
                self.invalid_filename.format(
                    regex=self.filename_regex.pattern))
        # Peek at the buffered upload: an empty peek means an empty file.
        if not (peek := value.file.peek()):
            raise ValidationError(self.file_is_empty.format(size=len(peek)))
        return value


# Raw file field validated as an uploaded image.
ImageType = NewType("ImageType", str, field=fields.Raw, validate=ImageValidator())
# Rule type discriminators.
MACHINE_LEARNING = 'machine_learning'
SAVED_QUERY = 'saved_query'
QUERY = 'query'

OPERATORS = ['equals']

# Built-in timeline template ids mapped to their display names.
TIMELINE_TEMPLATES: Final[dict] = {
    'db366523-f1c6-4c1f-8731-6ce5ed9e5717': 'Generic Endpoint Timeline',
    '91832785-286d-4ebe-b884-1a208d111a70': 'Generic Network Timeline',
    '76e52245-7519-4251-91ab-262fb1a1728c': 'Generic Process Timeline',
    '495ad7a7-316e-4544-8a0f-9c098daee76e': 'Generic Threat Match Timeline'
}

# Validated type aliases used across the rule schemas.
NonEmptyStr = NewType('NonEmptyStr', str, validate=validate.Length(min=1))
BranchVer = NewType('BranchVer', str, validate=validate.Regexp(BRANCH_PATTERN))
CardinalityFields = NewType('CardinalityFields', List[NonEmptyStr],
                            validate=validate.Length(min=0, max=3))
CodeString = NewType("CodeString", str)
ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN))
Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN))
FilterLanguages = Literal["kuery", "lucene"]
Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN))
Markdown = NewType("MarkdownField", CodeString)
Maturity = Literal['development', 'experimental', 'beta', 'production', 'deprecated']
MaxSignals = NewType("MaxSignals", int, validate=validate.Range(min=1))
Operator = Literal['equals']
OSType = Literal['windows', 'linux', 'macos']
PositiveInteger = NewType('PositiveInteger', int, validate=validate.Range(min=1))
# FIX: RiskScore was registered under the copy-pasted NewType name
# "MaxSignals"; give it its own name.
RiskScore = NewType("RiskScore", int, validate=validate.Range(min=1, max=100))
@dataclass
class _NumpyArrayDTO:
    """Serializable form of an ndarray: dtype name plus nested-list data."""
    dtype: str
    data: List[Any]


class NumpyField(Field):
    """Marshmallow field that round-trips numpy arrays via _NumpyArrayDTO.

    FIX: removed the redundant __init__ that only forwarded *args/**kwargs
    to the base class.
    """

    def _serialize(self, value: np.ndarray, *args, **kwargs):
        """Dump an ndarray as {"dtype": ..., "data": ...}; None passes through."""
        if value is None:
            return None
        return asdict(_NumpyArrayDTO(dtype=value.dtype.name, data=value.tolist()))

    def _deserialize(self, value, *args, **kwargs):
        """Rebuild an ndarray from the DTO dict; None passes through."""
        if value is None:
            return None
        np_array_obj = _NumpyArrayDTO(**value)
        return np.array(np_array_obj.data, dtype=np.dtype(np_array_obj.dtype))


# Dataclass-friendly ndarray type backed by NumpyField.
NumpyArray = NewType("NdArray", np.ndarray, field=NumpyField)

# FIX: __all__ must list names as strings, not the objects themselves.
__all__ = ["NumpyField", "NumpyArray"]
    PRODUCTION = 'production'
    CONSUMPTION = 'consumption'


class SummaryResolution(Enum):
    """Time resolution a summary can be aggregated at."""
    ALL = 'all'
    YEAR = 'year'
    MONTH = 'month'
    DAY = 'day'
    HOUR = 'hour'


# A single summary value; None means "no data for this interval".
SummaryGroupValue = NewType('SummaryGroupValue', int, allow_none=True)


@dataclass
class SummaryGroup:
    """
    TODO
    """
    # The grouping key (one label per grouping dimension).
    group: List[str] = field(default_factory=list)
    # One value per interval; None where data is missing.
    values: List[SummaryGroupValue] = field(default_factory=list)

    # NOTE: method body truncated in this view — it continues beyond
    # this chunk.
    def __add__(self, other):
        """
        :param SummaryGroup other:
        :rtype: SummaryGroup
        """
logger = logging.getLogger(__name__) class IPValidator(Validator): """ validator for strings containing IPs """ def __call__(self, value: Any) -> Any: try: IP(value) except ValueError: raise ValidationError("not an IP") return value IPType = NewType("IP", str, validate=IPValidator()) @dataclass class FirewallRule: """ dataclass for a firewall rule """ protocol: str port: int = field(metadata={"validate": marshmallow.validate.Range(min=1)}) @dataclass class FirewallConfig: """
class SCIMSchema(Enum):
    """SCIM schema URNs used by this API."""

    # NOTE(review): value appears redacted ('******') — presumably
    # 'urn:ietf:params:scim:schemas:core:2.0:User'; confirm before relying
    # on it.
    CORE_20_USER = '******'
    CORE_20_GROUP = 'urn:ietf:params:scim:schemas:core:2.0:Group'
    API_MESSAGES_20_SEARCH_REQUEST = 'urn:ietf:params:scim:api:messages:2.0:SearchRequest'
    API_MESSAGES_20_LIST_RESPONSE = 'urn:ietf:params:scim:api:messages:2.0:ListResponse'
    ERROR = 'urn:ietf:params:scim:api:messages:2.0:Error'
    NUTID_USER_V1 = 'https://scim.eduid.se/schema/nutid/user/v1'
    NUTID_GROUP_V1 = 'https://scim.eduid.se/schema/nutid/group/v1'
    NUTID_INVITE_CORE_V1 = 'https://scim.eduid.se/schema/nutid/invite/core-v1'
    NUTID_INVITE_V1 = 'https://scim.eduid.se/schema/nutid/invite/v1'
    NUTID_EVENT_CORE_V1 = 'https://scim.eduid.se/schema/nutid/event/core-v1'
    NUTID_EVENT_V1 = 'https://scim.eduid.se/schema/nutid/event/v1'
    DEBUG_V1 = 'https://scim.eduid.se/schema/nutid-DEBUG/v1'


# SCIMSchema enum (de)serialized by its URN value.
SCIMSchemaValue = NewType('SCIMSchemaValue', SCIMSchema, field=EnumField, enum=SCIMSchema, by_value=True)


class SCIMResourceType(Enum):
    """Resource types exposed by this SCIM API."""

    # NOTE(review): value appears redacted ('******') — presumably 'User';
    # confirm.
    USER = '******'
    GROUP = 'Group'
    INVITE = 'Invite'
    EVENT = 'Event'


class EmailType(Enum):
    """Classification of an e-mail address."""

    HOME = 'home'
    WORK = 'work'
    OTHER = 'other'
""" The task model. """ from dataclasses import field from typing import ClassVar, List, Optional, Type from marshmallow import Schema from marshmallow_dataclass import NewType, dataclass from shipyard_cli.fields import ObjectId from shipyard_cli.validators import validate_devices objectid = NewType('objectid', str, ObjectId) @dataclass(order=True) class Task: """A real-time task that can be deployed as a container to a node.""" _id: Optional[objectid] = field(metadata={'required': False}) file_id: Optional[objectid] = field(metadata={'required': False}) name: Optional[str] = field(metadata={'required': False}) runtime: Optional[int] = field(metadata={'required': False}) deadline: Optional[int] = field(metadata={'required': False}) period: Optional[int] = field(metadata={'required': False}) devices: List[str] = field(default_factory=lambda: [], metadata={ 'required': False, 'validate': validate_devices }) capabilities: List[str] = field(default_factory=lambda: [],
RING = 6 # Reminds me of Beatstream ? @dataclass class Metadata: cover: Optional[str] # path to album art ? creator: Optional[str] # Chart author background: Optional[str] # path to background image version: Optional[str] # freeform difficulty name id: Optional[int] mode: int time: Optional[int] # creation timestamp ? song: SongInfo PositiveInt = NewType("PositiveInt", int, validate=Range(min=0)) BeatTime = Tuple[PositiveInt, PositiveInt, PositiveInt] StrictlyPositiveDecimal = NewType("StrictlyPositiveDecimal", Decimal, validate=Range(min=0, min_inclusive=False)) @dataclass class BPMEvent: beat: BeatTime bpm: StrictlyPositiveDecimal ButtonIndex = NewType("ButtonIndex", int, validate=Range(min=0, max=15))