class VulnerabilitySchema(AutoSchema):
    """Marshmallow schema for Vulnerability objects.

    Dumps vulnerabilities together with their parent (host or service),
    evidence attachments and custom fields.  On load, the
    (parent, parent_type) pair is resolved to a host_id/service_id foreign
    key scoped to the workspace named in ``self.context``.
    """

    _id = fields.Integer(dump_only=True, attribute='id')
    _rev = fields.String(dump_only=True, default='')
    # NOTE(review): default=[] is a shared mutable default — harmless only
    # as long as it is never mutated in place; confirm.
    _attachments = fields.Method(serialize='get_attachments',
                                 deserialize='load_attachments', default=[])
    owned = fields.Boolean(dump_only=True, default=False)
    owner = PrimaryKeyRelatedField('username', dump_only=True,
                                   attribute='creator')
    impact = SelfNestedField(ImpactSchema())
    # 'desc' is the writable alias; 'description' is dump-only.
    desc = fields.String(attribute='description')
    description = fields.String(dump_only=True)
    policyviolations = fields.List(fields.String,
                                   attribute='policy_violations')
    refs = fields.List(fields.String(), attribute='references')
    issuetracker = fields.Method(serialize='get_issuetracker',
                                 dump_only=True)
    parent = fields.Method(serialize='get_parent',
                           deserialize='load_parent', required=True)
    parent_type = MutableField(fields.Method('get_parent_type'),
                               fields.String(),
                               required=True)
    tags = PrimaryKeyRelatedField('name', dump_only=True, many=True)
    easeofresolution = fields.String(
        attribute='ease_of_resolution',
        validate=OneOf(Vulnerability.EASE_OF_RESOLUTIONS),
        allow_none=True)
    hostnames = PrimaryKeyRelatedField('name', many=True, dump_only=True)
    service = fields.Nested(ServiceSchema(only=[
        '_id', 'ports', 'status', 'protocol', 'name', 'version', 'summary'
    ]), dump_only=True)
    host = fields.Integer(dump_only=True, attribute='host_id')
    severity = SeverityField(required=True)
    status = fields.Method(
        serialize='get_status',
        validate=OneOf(Vulnerability.STATUSES + ['opened']),
        deserialize='load_status')
    type = fields.Method(serialize='get_type',
                         deserialize='load_type',
                         required=True)
    obj_id = fields.String(dump_only=True, attribute='id')
    target = fields.String(dump_only=True, attribute='target_host_ip')
    host_os = fields.String(dump_only=True, attribute='target_host_os')
    metadata = SelfNestedField(CustomMetadataSchema())
    date = fields.DateTime(attribute='create_date',
                           dump_only=True)  # This is only used for sorting
    custom_fields = FaradayCustomField(table_name='vulnerability',
                                       attribute='custom_fields')

    class Meta:
        model = Vulnerability
        fields = (
            '_id', 'status',
            'issuetracker', 'description', 'parent', 'parent_type',
            'tags', 'severity', '_rev', 'easeofresolution', 'owned',
            'hostnames', 'owner',
            'date', 'data', 'refs',
            'desc', 'impact', 'confirmed', 'name',
            'service', 'obj_id', 'type', 'policyviolations',
            '_attachments',
            'target', 'host_os', 'resolution', 'metadata',
            'custom_fields')

    def get_type(self, obj):
        # The API exposes the concrete model class name as the "type".
        return obj.__class__.__name__

    def get_attachments(self, obj):
        """Dump each evidence file keyed by filename.

        Raises ValidationError when an evidence file fails to serialize.
        Uses the marshmallow-2 style ``dump`` returning (data, errors).
        """
        res = {}
        for file_obj in obj.evidence:
            ret, errors = EvidenceSchema().dump(file_obj)
            if errors:
                raise ValidationError(errors, data=ret)
            res[file_obj.filename] = ret
        return res

    def load_attachments(self, value):
        # Attachments are handled elsewhere; accept the raw payload as-is.
        return value

    def get_parent(self, obj):
        # A vulnerability hangs off either a service or a host; prefer the
        # service id when both could be set.
        return obj.service_id or obj.host_id

    def get_parent_type(self, obj):
        assert obj.service_id is not None or obj.host_id is not None
        return 'Service' if obj.service_id is not None else 'Host'

    def get_status(self, obj):
        # Legacy API spelling: internal 'open' is exposed as 'opened'.
        if obj.status == 'open':
            return 'opened'
        return obj.status

    def get_issuetracker(self, obj):
        # Kept for API compatibility; no issue-tracker data is stored.
        return {}

    def load_status(self, value):
        # Inverse of get_status: map the legacy 'opened' back to 'open'.
        if value == 'opened':
            return 'open'
        return value

    def load_type(self, value):
        # NOTE(review): unknown type names fall through and return None —
        # presumably rejected by later validation; confirm upstream.
        if value == 'Vulnerability':
            return 'vulnerability'
        if value == 'VulnerabilityWeb':
            return 'vulnerability_web'

    def load_parent(self, value):
        try:
            # sometimes api requests send str or unicode.
            value = int(value)
        except ValueError:
            raise ValidationError("Invalid parent type")
        return value

    @post_load
    def post_load_impact(self, data):
        # Unflatten impact (move data[impact][*] to data[*])
        impact = data.pop('impact', None)
        if impact:
            data.update(impact)
        return data

    @post_load
    def post_load_parent(self, data):
        """Resolve (parent, parent_type) into a host_id/service_id key."""
        # schema guarantees that parent_type exists.
        parent_class = None
        parent_type = data.pop('parent_type', None)
        parent_id = data.pop('parent', None)
        if not (parent_type and parent_id):
            # Probably a partial load, since they are required
            # NOTE(review): returns None here — relies on marshmallow
            # keeping the previous data when a hook returns None; confirm
            # against the pinned marshmallow version.
            return
        if parent_type == 'Host':
            parent_class = Host
            parent_field = 'host_id'
        if parent_type == 'Service':
            parent_class = Service
            parent_field = 'service_id'
        if not parent_class:
            raise ValidationError('Unknown parent type')
        try:
            # Scope the lookup to the current workspace to prevent
            # cross-workspace parent references.
            parent = db.session.query(parent_class).join(Workspace).filter(
                Workspace.name == self.context['workspace_name'],
                parent_class.id == parent_id
            ).one()
        except NoResultFound:
            raise ValidationError(
                'Parent id not found: {}'.format(parent_id))
        data[parent_field] = parent.id
        # TODO migration: check what happens when updating the parent from
        # service to host or viceverse
        return data
class GelfSchema(Schema):
    """Serialize log records into GELF 1.1 payloads.

    The declared fields map ``logging.LogRecord`` attributes to GELF keys;
    after dumping, :meth:`fix_additional_fields` flattens the payload so
    nested structures become ``_``-prefixed additional fields.
    """

    class Meta:
        unknown = EXCLUDE

    version = fields.Constant("1.1")
    # Callable default: hostname is resolved at dump time.
    host = fields.String(required=True, default=socket.gethostname)
    short_message = fields.Method('to_message')
    full_message = fields.String()
    timestamp = fields.Method('to_timestamp')
    level = fields.Method('to_syslog_level')
    lineno = fields.Integer(data_key="line")
    pathname = fields.String(data_key="file")

    @classmethod
    def _forge_key(cls, key, _value):
        """Hook for subclasses to rewrite a key; default keeps it as-is."""
        return key

    @classmethod
    def to_syslog_level(cls, value):
        """Map the record's ``levelno`` to a syslog level (fallback: 1)."""
        return SYSLOG_LEVELS.get(value.levelno, 1)

    @classmethod
    def to_timestamp(cls, value):
        """Return the record's creation time, or now when it is unset."""
        if value.created:
            return value.created
        return time.time()

    @classmethod
    def to_message(cls, value):
        """Render the record message, interpolating record attributes.

        Falls back to the raw message on any formatting error (best
        effort — record attributes may not match the format string).
        """
        # noinspection PyBroadException
        try:
            return value.getMessage() % vars(value)
        except Exception:
            return value.getMessage()

    @classmethod
    def format_key(cls, xpath, key, value):
        """Build the flattened GELF key for *key* under path *xpath*.

        Top-level GELF 1.1 standard fields keep their name; every other
        field is prefixed with ``_`` (GELF "additional field" convention).
        """
        if xpath in (None, ""):
            if key in GELF_1_1_FIELDS:
                return key
            elif key in (None, ""):
                return ""
            else:
                return "_{}".format(cls._forge_key(key, value))
        else:
            return "{}_{}".format(xpath, cls._forge_key(key, value))

    @classmethod
    def to_flat_dict(cls, xpath, key, value):
        """Recursively flatten *value* into ``{gelf_key: scalar}`` pairs."""
        parts = dict()
        if isinstance(value, dict):
            # Hoisted: the sub-path is loop-invariant.
            sub_xpath = cls.format_key(xpath, key, value)
            for subkey, subvalue in value.items():
                parts.update(cls.to_flat_dict(sub_xpath, subkey, subvalue))
        elif isinstance(value, (list, tuple)):
            if len(value) < 20:
                # Small sequences: one flattened field per element.
                sub_xpath = cls.format_key(xpath, key, value)
                for idx, item in enumerate(value):
                    parts.update(cls.to_flat_dict(sub_xpath, idx, item))
            else:
                # Large sequences are collapsed to a single string field;
                # best effort — drop the value if str() fails.
                try:
                    parts[cls.format_key(xpath, key, value)] = str(value)
                except Exception:
                    pass
        else:
            parts[cls.format_key(xpath, key, value)] = value
        return parts

    @post_dump
    def fix_additional_fields(self, data, **kwargs):
        """Flatten the dumped payload into the final GELF message."""
        return self.to_flat_dict("", "", data)
class BadSerializer(Schema):
    """Deliberately misconfigured schema: 'invalid' names a method that
    does not exist — presumably a fixture for Method-field error handling
    (TODO confirm against the tests that use it)."""
    bad_field = fields.Method("invalid")
class SurveyResponseSchema(ma.SQLAlchemyAutoSchema):
    """Serializer for SurveyResponse rows.

    The raw ``*_location_*`` columns are excluded from auto-generation and
    re-exposed as two aggregated dicts, ``created_location`` and
    ``updated_location``, built by a shared helper (the two getters were
    previously copy-paste duplicates).
    """

    class Meta:
        model = models.SurveyResponse
        exclude = ['created_location_coordinates', 'created_location_altitude',
                   'created_location_altitude_accuracy', 'created_location_dt',
                   'created_location_heading',
                   'created_location_position_accuracy',
                   'created_location_speed', 'updated_location_coordinates',
                   'updated_location_altitude',
                   'updated_location_altitude_accuracy', 'updated_location_dt',
                   'updated_location_heading',
                   'updated_location_position_accuracy',
                   'updated_location_speed']

    created_by = ma.Nested(UserSchema,
                           only=['id', 'email', 'username', 'name', 'color'])
    updated_by = ma.Nested(UserSchema,
                           only=['id', 'email', 'username', 'name', 'color'])
    status = ma.Nested(SurveyResponseStatusSchema)
    survey = ma.auto_field('survey_id', dump_only=True)
    case_id = ma.auto_field('case_id', dump_only=True)
    activity_id = ma.auto_field('activity_id', dump_only=True)
    created_location = fields.Method('get_created_location')
    updated_location = fields.Method('get_updated_location')

    def _location_payload(self, obj, prefix):
        """Build the location dict from the ``<prefix>_location_*`` columns.

        :param obj: SurveyResponse instance being dumped.
        :param prefix: ``'created'`` or ``'updated'``.
        :returns: dict with latitude/longitude, accuracies, heading, speed
            and the ISO-formatted timestamp (``None`` where data is absent).
        """
        def col(name):
            return getattr(obj, '{}_location_{}'.format(prefix, name))

        coordinates = col('coordinates')
        if coordinates is not None:
            # Coordinates are stored as a spatial geometry; extract the
            # lat/lon components via the database (ST_Y/ST_X).
            latitude = db.session.scalar(coordinates.ST_Y())
            longitude = db.session.scalar(coordinates.ST_X())
        else:
            latitude = None
            longitude = None
        dt = col('dt')
        return {
            "latitude": latitude,
            "longitude": longitude,
            "position_accuracy": col('position_accuracy'),
            "altitude": col('altitude'),
            "altitude_accuracy": col('altitude_accuracy'),
            "heading": col('heading'),
            "speed": col('speed'),
            "location_recorded_dt":
                dt.isoformat() if dt is not None else None
        }

    def get_created_location(self, obj):
        """Serialize the creation-time location columns."""
        return self._location_payload(obj, 'created')

    def get_updated_location(self, obj):
        """Serialize the update-time location columns."""
        return self._location_payload(obj, 'updated')
class JSONSchema(Schema):
    """Converts to JSONSchema as defined by http://json-schema.org/."""

    properties = fields.Method("get_properties")
    type = fields.Constant("object")
    required = fields.Method("get_required")

    def __init__(self, *args, **kwargs) -> None:
        """Setup internal cache of nested fields, to prevent recursion.

        :param bool nested: if `True`, this instance describes a nested
            schema and :meth:`wrap` skips the root-definitions wrapper.
        :param bool props_ordered: if `True`, properties keep their
            declaration order; otherwise they are sorted. Default `False`.
            Note: for the marshmallow schema itself, field ordering must
            also be enabled (via `class Meta`, attribute `ordered`).
        """
        self._nested_schema_classes: typing.Dict[str, typing.Dict[str, typing.Any]] = {}
        self.nested = kwargs.pop("nested", False)
        self.props_ordered = kwargs.pop("props_ordered", False)
        setattr(self.opts, "ordered", self.props_ordered)
        super().__init__(*args, **kwargs)

    def get_properties(self, obj) -> typing.Dict[str, typing.Dict[str, typing.Any]]:
        """Fill out properties field."""
        properties = self.dict_class()

        if self.props_ordered:
            fields_items_sequence = obj.fields.items()
        else:
            fields_items_sequence = sorted(obj.fields.items())

        for field_name, field in fields_items_sequence:
            schema = self._get_schema_for_field(obj, field)
            # A "name" entry in field metadata overrides the field name.
            properties[field.metadata.get("name") or field.name] = schema

        return properties

    def get_required(self, obj) -> typing.Union[typing.List[str], _Missing]:
        """Fill out required field."""
        required = []

        for field_name, field in sorted(obj.fields.items()):
            if field.required:
                required.append(field.name)

        # JSON Schema forbids an empty "required" array; omit it instead.
        return required or missing

    def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]:
        """Get schema definition from python type."""
        json_schema = {"title": field.attribute or field.name or ""}

        for key, val in PY_TO_JSON_TYPES_MAP[pytype].items():
            json_schema[key] = val

        if field.dump_only:
            json_schema["readonly"] = True

        if field.default is not missing:
            json_schema["default"] = field.default

        if ALLOW_ENUMS and isinstance(field, EnumField):
            json_schema["enum"] = self._get_enum_values(field)

        if field.allow_none:
            previous_type = json_schema["type"]
            json_schema["type"] = [previous_type, "null"]

        # NOTE: doubled up to maintain backwards compatibility
        metadata = field.metadata.get("metadata", {})
        metadata.update(field.metadata)

        for md_key, md_val in metadata.items():
            if md_key in ("metadata", "name"):
                continue
            json_schema[md_key] = md_val

        if isinstance(field, fields.List):
            json_schema["items"] = self._get_schema_for_field(obj, field.inner)

        if isinstance(field, fields.Dict):
            json_schema["additionalProperties"] = (
                self._get_schema_for_field(obj, field.value_field)
                if field.value_field
                else {}
            )
        return json_schema

    def _get_enum_values(self, field) -> typing.List[str]:
        """Return the enum member names for an EnumField loaded by name."""
        assert ALLOW_ENUMS and isinstance(field, EnumField)

        if field.load_by == LoadDumpOptions.value:
            # Python allows enum values to be almost anything, so it's easier to just load from the
            # names of the enum's which will have to be strings.
            raise NotImplementedError(
                "Currently do not support JSON schema for enums loaded by value"
            )

        return [value.name for value in field.enum]

    def _from_union_schema(
        self, obj, field
    ) -> typing.Dict[str, typing.List[typing.Any]]:
        """Get a union type schema.

        Uses anyOf to allow the value to be any of the provided sub fields.
        """
        assert ALLOW_UNIONS and isinstance(field, Union)

        return {
            "anyOf": [
                self._get_schema_for_field(obj, sub_field)
                for sub_field in field._candidate_fields
            ]
        }

    def _get_python_type(self, field):
        """Get python type based on field subclass"""
        for map_class, pytype in MARSHMALLOW_TO_PY_TYPES_PAIRS:
            if issubclass(field.__class__, map_class):
                return pytype

        raise UnsupportedValueError("unsupported field type %s" % field)

    def _get_schema_for_field(self, obj, field):
        """Get schema and validators for field."""
        if hasattr(field, "_jsonschema_type_mapping"):
            # Field provides its own JSON-schema mapping.
            schema = field._jsonschema_type_mapping()
        elif "_jsonschema_type_mapping" in field.metadata:
            schema = field.metadata["_jsonschema_type_mapping"]
        else:
            if isinstance(field, fields.Nested):
                # Special treatment for nested fields.
                schema = self._from_nested_schema(obj, field)
            elif ALLOW_UNIONS and isinstance(field, Union):
                schema = self._from_union_schema(obj, field)
            else:
                pytype = self._get_python_type(field)
                schema = self._from_python_type(obj, field, pytype)
        # Apply any and all validators that field may have
        for validator in field.validators:
            if validator.__class__ in FIELD_VALIDATORS:
                schema = FIELD_VALIDATORS[validator.__class__](
                    schema, field, validator, obj
                )
            else:
                # Validators may declare a base class whose handler applies.
                base_class = getattr(
                    validator, "_jsonschema_base_validator_class", None
                )
                if base_class is not None and base_class in FIELD_VALIDATORS:
                    schema = FIELD_VALIDATORS[base_class](schema, field, validator, obj)
        return schema

    def _from_nested_schema(self, obj, field):
        """Support nested field."""
        # `field.nested` may be a class, an instance, or a string name.
        if isinstance(field.nested, (str, bytes)):
            nested = get_class(field.nested)
        else:
            nested = field.nested

        if isclass(nested) and issubclass(nested, Schema):
            name = nested.__name__
            only = field.only
            exclude = field.exclude
            nested_cls = nested
            nested_instance = nested(only=only, exclude=exclude)
        else:
            nested_cls = nested.__class__
            name = nested_cls.__name__
            nested_instance = nested

        outer_name = obj.__class__.__name__
        # If this is not a schema we've seen, and it's not this schema (checking this for recursive schemas),
        # put it in our list of schema defs
        if name not in self._nested_schema_classes and name != outer_name:
            wrapped_nested = self.__class__(nested=True)
            wrapped_dumped = wrapped_nested.dump(nested_instance)

            wrapped_dumped["additionalProperties"] = _resolve_additional_properties(
                nested_cls
            )

            self._nested_schema_classes[name] = wrapped_dumped

            self._nested_schema_classes.update(wrapped_nested._nested_schema_classes)

        # and the schema is just a reference to the def
        schema = {"type": "object", "$ref": "#/definitions/{}".format(name)}

        # NOTE: doubled up to maintain backwards compatibility
        metadata = field.metadata.get("metadata", {})
        metadata.update(field.metadata)

        for md_key, md_val in metadata.items():
            if md_key in ("metadata", "name"):
                continue
            schema[md_key] = md_val

        if field.many:
            schema = {
                "type": "array" if field.required else ["array", "null"],
                "items": schema,
            }

        return schema

    def dump(self, obj, **kwargs):
        """Take obj for later use: using class name to namespace definition."""
        self.obj = obj
        return super().dump(obj, **kwargs)

    @post_dump
    def wrap(self, data, **_) -> typing.Dict[str, typing.Any]:
        """Wrap this with the root schema definitions."""
        if self.nested:  # no need to wrap, will be in outer defs
            return data

        cls = self.obj.__class__
        name = cls.__name__

        data["additionalProperties"] = _resolve_additional_properties(cls)

        self._nested_schema_classes[name] = data
        root = {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "definitions": self._nested_schema_classes,
            "$ref": "#/definitions/{name}".format(name=name),
        }
        return root
class PlaceSchema(Schema):
    """Schema whose 'x' field dumps via a method but loads as a string —
    presumably a fixture exercising MutableField's asymmetric dump/load."""
    name = fields.String()
    x = MutableField(fields.Method('get_x'), fields.String())

    def get_x(self, obj):
        # Constant dump value regardless of the object being serialized.
        return 5
class ActivitySchema(ma.SQLAlchemyAutoSchema):
    """Serializer for Activity rows.

    The raw ``*_location_*`` columns are excluded and re-exposed as two
    aggregated dicts (``created_location`` / ``updated_location``) built by
    a shared helper (the two getters were previously copy-paste
    duplicates).  ``case`` is a flattened summary of the linked case.
    """

    class Meta:
        model = models.Activity
        additional = ['surveys', 'documents', 'history']
        exclude = ['created_location_coordinates', 'created_location_altitude',
                   'created_location_altitude_accuracy', 'created_location_dt',
                   'created_location_heading',
                   'created_location_position_accuracy',
                   'created_location_speed', 'updated_location_coordinates',
                   'updated_location_altitude',
                   'updated_location_altitude_accuracy', 'updated_location_dt',
                   'updated_location_heading',
                   'updated_location_position_accuracy',
                   'updated_location_speed']

    activity_definition = ma.Nested(ActivityDefinitionSchema,
                                    only=['id', 'name'])
    # case = ma.Nested(CaseSchema)
    created_by = ma.Nested(UserSchema,
                           only=['id', 'email', 'username', 'name', 'color'])
    updated_by = ma.Nested(UserSchema,
                           only=['id', 'email', 'username', 'name', 'color'])
    completed_by = ma.Nested(UserSchema,
                             only=['id', 'email', 'username', 'name', 'color'])
    notes = ma.Nested(NoteSchema, many=True,
                      only=['id', 'note', 'created_at', 'updated_at',
                            'created_by', 'updated_by'])
    files = ma.Nested(
        UploadedFileSchema,
        many=True,
        only=['id', 'original_filename', 'remote_filename', 'created_at',
              'created_by', 'url']
    )
    created_location = fields.Method('get_created_location')
    updated_location = fields.Method('get_updated_location')
    case = fields.Method('get_case')

    def _location_payload(self, obj, prefix):
        """Build the location dict from the ``<prefix>_location_*`` columns.

        :param obj: Activity instance being dumped.
        :param prefix: ``'created'`` or ``'updated'``.
        :returns: dict with latitude/longitude, accuracies, heading, speed
            and the ISO-formatted timestamp (``None`` where data is absent).
        """
        def col(name):
            return getattr(obj, '{}_location_{}'.format(prefix, name))

        coordinates = col('coordinates')
        if coordinates is not None:
            # Coordinates are stored as a spatial geometry; extract the
            # lat/lon components via the database (ST_Y/ST_X).
            latitude = db.session.scalar(coordinates.ST_Y())
            longitude = db.session.scalar(coordinates.ST_X())
        else:
            latitude = None
            longitude = None
        dt = col('dt')
        return {
            "latitude": latitude,
            "longitude": longitude,
            "position_accuracy": col('position_accuracy'),
            "altitude": col('altitude'),
            "altitude_accuracy": col('altitude_accuracy'),
            "heading": col('heading'),
            "speed": col('speed'),
            "location_recorded_dt":
                dt.isoformat() if dt is not None else None
        }

    def get_created_location(self, obj):
        """Serialize the creation-time location columns."""
        return self._location_payload(obj, 'created')

    def get_updated_location(self, obj):
        """Serialize the update-time location columns."""
        return self._location_payload(obj, 'updated')

    def get_case(self, obj):
        """Summarize the linked case (id/key/name plus definition).

        Returns an all-None placeholder with the same shape when the
        activity has no case, so consumers can rely on the structure.
        """
        if obj.case is not None:
            return {
                'id': obj.case.id,
                'key': obj.case.key,
                'name': obj.case.name,
                'definition': {
                    'id': obj.case.case_definition.id,
                    'key': obj.case.case_definition.key,
                    'name': obj.case.case_definition.name
                }
            }
        else:
            return {'id': None, 'key': None, 'name': None,
                    'definition': {'id': None, 'key': None, 'name': None}}
class DataCiteSchema(Schema):
    """Base class for DataCite metadata schemas.

    Dumps Zenodo record dicts (``obj['metadata']``) into the DataCite
    metadata shape.  Nested dumps use the marshmallow-2 style
    ``.dump(...).data`` API.
    """

    identifier = fields.Method('get_identifier', attribute='metadata.doi')
    titles = fields.List(
        fields.Nested(TitleSchema), attribute='metadata.title')
    publisher = fields.Constant('Zenodo')
    publicationYear = fields.Function(
        lambda o: str(arrow.get(o['metadata']['publication_date']).year))
    subjects = fields.Method('get_subjects')
    dates = fields.Method('get_dates')
    language = fields.Method('get_language')
    version = fields.Str(attribute='metadata.version')
    resourceType = fields.Method('get_type')
    alternateIdentifiers = fields.List(
        fields.Nested(AlternateIdentifierSchema),
        attribute='metadata.alternate_identifiers',
    )
    relatedIdentifiers = fields.Method('get_related_identifiers')
    rightsList = fields.Method('get_rights')
    descriptions = fields.Method('get_descriptions')

    @post_dump
    def cleanup(self, data):
        """Clean the data."""
        # Remove the language if Alpha-2 code was not found
        if 'language' in data and data['language'] is None:
            del data['language']
        return data

    def get_identifier(self, obj):
        """Get record main identifier."""
        doi = obj['metadata'].get('doi', '')
        if is_doi_locally_managed(doi):
            return {
                'identifier': doi,
                'identifierType': 'DOI'
            }
        else:
            # Externally managed DOI: point at the record page URL instead.
            recid = obj.get('metadata', {}).get('recid', '')
            return {
                'identifier': ui_link_for('record_html', id=recid),
                'identifierType': 'URL',
            }

    def get_language(self, obj):
        """Export language to the Alpha-2 code (if available)."""
        lang = obj['metadata'].get('language', None)
        if lang:
            try:
                # NOTE(review): newer pycountry versions return None rather
                # than raising KeyError here — confirm the pinned version.
                l = pycountry.languages.get(alpha_3=lang)
            except KeyError:
                return None
            if not hasattr(l, 'alpha_2'):
                return None
            return l.alpha_2
        return None

    def get_descriptions(self, obj):
        """Get descriptions."""
        items = []
        desc = obj['metadata']['description']
        max_descr_size = current_app.config.get(
            'DATACITE_MAX_DESCRIPTION_SIZE', 20000)
        if desc:
            items.append({
                'description': desc[:max_descr_size],
                'descriptionType': 'Abstract'
            })
        notes = obj['metadata'].get('notes')
        if notes:
            items.append({
                'description': notes[:max_descr_size],
                'descriptionType': 'Other'
            })
        refs = obj['metadata'].get('references')
        if refs:
            # References are packed into one JSON string, truncated to the
            # configured maximum description size.
            items.append({
                'description': json.dumps({
                    'references': [
                        r['raw_reference'] for r in refs
                        if 'raw_reference' in r]
                })[:max_descr_size],
                'descriptionType': 'Other'
            })
        return items

    def get_rights(self, obj):
        """Get rights."""
        items = []
        # license
        license_url = obj['metadata'].get('license', {}).get('url')
        license_text = obj['metadata'].get('license', {}).get('title')
        if license_url and license_text:
            items.append({
                'rightsURI': license_url,
                'rights': license_text,
            })
        # info:eu-repo
        items.append({
            'rightsURI': 'info:eu-repo/semantics/{}Access'.format(
                obj['metadata']['access_right']),
            'rights': '{0} access'.format(
                obj['metadata']['access_right']).title()
        })
        return items

    def get_type(self, obj):
        """Resource type."""
        t = ObjectType.get_by_dict(obj['metadata']['resource_type'])
        type_ = {
            'resourceTypeGeneral': t['datacite']['general'],
            'resourceType': t['datacite'].get('type'),
        }
        oa_type = ObjectType.get_openaire_subtype(obj['metadata'])
        # NOTE: This overwrites the resourceType if the configuration
        # of the OpenAIRE subtypes overlaps with regular subtypes.
        if oa_type:
            type_['resourceType'] = oa_type
        return type_

    def get_related_identifiers(self, obj):
        """Related identifiers."""
        accepted_types = [
            'doi', 'ark', 'ean13', 'eissn', 'handle', 'isbn', 'issn',
            'istc', 'lissn', 'lsid', 'purl', 'upc', 'url', 'urn', 'ads',
            'arxiv', 'bibcode',
        ]
        s = RelatedIdentifierSchema()
        items = []
        # Only identifiers whose scheme is DataCite-supported are exported.
        for r in obj['metadata'].get('related_identifiers', []):
            if r['scheme'] in accepted_types:
                items.append(s.dump(r).data)
        doi = obj['metadata'].get('doi', '')
        if not is_doi_locally_managed(doi):
            items.append(s.dump({
                'identifier': doi,
                'scheme': 'doi',
                'relation': 'IsIdenticalTo',
            }).data)
        # OpenAIRE community identifiers
        openaire_comms = resolve_openaire_communities(
            obj['metadata'].get('communities', []))
        for oa_comm in openaire_comms:
            items.append(s.dump({
                'identifier': openaire_community_identifier(oa_comm),
                'scheme': 'url',
                'relation': 'IsPartOf',
            }).data)
        return items

    def get_subjects(self, obj):
        """Get subjects."""
        items = []
        for s in obj['metadata'].get('keywords', []):
            items.append({'subject': s})
        for s in obj['metadata'].get('subjects', []):
            items.append({
                'subject': s['identifier'],
                'subjectScheme': s['scheme'],
            })
        return items

    def get_dates(self, obj):
        """Get dates."""
        s = DateSchema()
        if obj['metadata']['access_right'] == 'embargoed' and \
                obj['metadata'].get('embargo_date'):
            # Embargoed records expose both the availability (embargo end)
            # and acceptance (publication) dates.
            return [
                s.dump(dict(
                    date=obj['metadata']['embargo_date'],
                    type='Available')).data,
                s.dump(dict(
                    date=obj['metadata']['publication_date'],
                    type='Accepted')).data,
            ]
        else:
            return [s.dump(dict(
                date=obj['metadata']['publication_date'],
                type='Issued')).data,
            ]
def test_method_prefers_serialize_over_method_name(self):
    """`serialize` wins when both it and `method_name` are given."""
    field = fields.Method(serialize='serialize', method_name='method')
    assert field.serialize_method_name == 'serialize'
class VulnerabilityTemplateSchema(AutoSchema):
    """Schema for vulnerability templates (knowledge-base entries).

    References are dumped as one comma-separated string but may be loaded
    either as a list or as a comma-separated string.
    """

    _id = fields.Integer(dump_only=True, attribute='id')
    id = fields.Integer(dump_only=True, attribute='id')
    _rev = fields.String(default='', dump_only=True)
    cwe = fields.String(
        dump_only=True,
        default=''
    )  # deprecated field, the legacy data is added to refs on import
    exploitation = SeverityField(attribute='severity', required=True)
    references = fields.Method('get_references',
                               deserialize='load_references',
                               required=True)
    refs = fields.List(fields.String(), dump_only=True,
                       attribute='references')
    desc = fields.String(dump_only=True, attribute='description')
    data = fields.String(attribute='data')
    impact = SelfNestedField(ImpactSchema())
    easeofresolution = fields.String(attribute='ease_of_resolution',
                                     validate=OneOf(
                                         Vulnerability.EASE_OF_RESOLUTIONS),
                                     allow_none=True)
    policyviolations = fields.List(fields.String,
                                   attribute='policy_violations')
    # Here we use vulnerability instead of vulnerability_template to avoid duplicate row
    # in the custom_fields_schema table.
    # All validation will be against vulnerability table.
    customfields = FaradayCustomField(table_name='vulnerability',
                                      attribute='custom_fields')

    class Meta:
        model = VulnerabilityTemplate
        fields = ('id', '_id', '_rev', 'cwe', 'description', 'desc',
                  'exploitation', 'name', 'references', 'refs',
                  'resolution', 'impact', 'easeofresolution',
                  'policyviolations', 'data', 'customfields')

    def get_references(self, obj):
        # Dump the related reference templates as "name1, name2, ...".
        return ', '.join(
            map(lambda ref_tmpl: ref_tmpl.name,
                obj.reference_template_instances))

    def load_references(self, value):
        """Normalize references into a list of non-empty strings.

        Raises ValidationError for non-string/non-list input or when any
        parsed reference name is empty.
        """
        if isinstance(value, list):
            references = value
        # NOTE(review): `unicode` is a Python-2 builtin; under Python 3
        # this line raises NameError unless it is aliased at import time
        # (e.g. via a compat shim) — confirm.
        elif isinstance(value, (unicode, str)):
            if len(value) == 0:
                # Required because "".split(",") == [""]
                return []
            references = [ref.strip() for ref in value.split(',')]
        else:
            raise ValidationError('references must be a either a string '
                                  'or a list')
        if any(len(ref) == 0 for ref in references):
            raise ValidationError('Empty name detected in reference')
        return references

    @post_load
    def post_load_impact(self, data):
        # Unflatten impact (move data[impact][*] to data[*])
        impact = data.pop('impact', None)
        if impact:
            data.update(impact)
        return data
class LegacyMetadataSchemaV1(common.CommonMetadataSchemaV1):
    """Legacy JSON metadata."""

    upload_type = fields.String(
        attribute='resource_type.type',
        required=True,
        validate=validate.OneOf(choices=ObjectType.get_types()),
    )
    publication_type = fields.Method(
        'dump_publication_type',
        attribute='resource_type.subtype',
        validate=validate.OneOf(
            choices=ObjectType.get_subtypes('publication')),
    )
    image_type = fields.Method(
        'dump_image_type',
        attribute='resource_type.subtype',
        validate=validate.OneOf(choices=ObjectType.get_subtypes('image')),
    )

    license = fields.Method('dump_license', 'load_license')
    communities = fields.Method('dump_communities', 'load_communities')
    grants = fields.Method('dump_grants', 'load_grants')

    prereserve_doi = fields.Method('dump_prereservedoi', 'load_prereservedoi')

    # Flat legacy-API names mapped onto nested attributes ("dot keys");
    # they are re-nested on load by merge_keys() below.
    journal_title = SanitizedUnicode(attribute='journal.title')
    journal_volume = SanitizedUnicode(attribute='journal.volume')
    journal_issue = SanitizedUnicode(attribute='journal.issue')
    journal_pages = SanitizedUnicode(attribute='journal.pages')

    conference_title = SanitizedUnicode(attribute='meeting.title')
    conference_acronym = SanitizedUnicode(attribute='meeting.acronym')
    conference_dates = SanitizedUnicode(attribute='meeting.dates')
    conference_place = SanitizedUnicode(attribute='meeting.place')
    conference_url = SanitizedUrl(attribute='meeting.url')
    conference_session = SanitizedUnicode(attribute='meeting.session')
    conference_session_part = SanitizedUnicode(
        attribute='meeting.session_part')

    imprint_isbn = SanitizedUnicode(attribute='imprint.isbn')
    imprint_place = SanitizedUnicode(attribute='imprint.place')
    imprint_publisher = SanitizedUnicode(attribute='imprint.publisher')

    partof_pages = SanitizedUnicode(attribute='part_of.pages')
    partof_title = SanitizedUnicode(attribute='part_of.title')

    thesis_university = SanitizedUnicode(attribute='thesis.university')
    thesis_supervisors = fields.Nested(common.PersonSchemaV1, many=True,
                                       attribute='thesis.supervisors')

    def _dump_subtype(self, obj, type_):
        """Get subtype."""
        if obj.get('resource_type', {}).get('type') == type_:
            return obj.get('resource_type', {}).get('subtype', missing)
        return missing

    def dump_publication_type(self, obj):
        """Get publication type."""
        return self._dump_subtype(obj, 'publication')

    def dump_image_type(self, obj):
        """Get publication type."""
        return self._dump_subtype(obj, 'image')

    def dump_license(self, obj):
        """Dump license."""
        return obj.get('license', {}).get('id', missing)

    def load_license(self, data):
        """Load license.

        Accepts either a plain license id string or a dict with an 'id'.
        """
        if isinstance(data, six.string_types):
            license = data
        # NOTE(review): if data is neither a string nor a dict, `license`
        # is unbound and the format() below raises UnboundLocalError
        # instead of a ValidationError — confirm whether upstream
        # validation makes that unreachable.
        if isinstance(data, dict):
            license = data['id']
        return {'$ref': 'https://dx.zenodo.org/licenses/{0}'.format(license)}

    def dump_grants(self, obj):
        """Get grants."""
        res = []
        for g in obj.get('grants', []):
            # FP7 EC grants are exposed by their bare code; others by
            # their internal id.
            if g.get('program', {}) == 'FP7' and \
                    g.get('funder', {}).get('doi') == '10.13039/501100000780':
                res.append(dict(id=g['code']))
            else:
                res.append(dict(id=g['internal_id']))
        return res or missing

    def load_grants(self, data):
        """Load grants."""
        if not isinstance(data, list):
            raise ValidationError(_('Not a list.'))
        res = set()
        for g in data:
            if not isinstance(g, dict):
                raise ValidationError(_('Element not an object.'))
            g = g.get('id')
            if not g:
                continue
            # FP7 project grant
            if not g.startswith('10.13039/'):
                g = '10.13039/501100000780::{0}'.format(g)
            res.add(g)
        return [{
            '$ref': 'https://dx.zenodo.org/grants/{0}'.format(grant_id)
        } for grant_id in res] or missing

    def dump_communities(self, obj):
        """Dump communities type."""
        return [dict(identifier=x) for x in obj.get('communities', [])] \
            or missing

    def load_communities(self, data):
        """Load communities type."""
        if not isinstance(data, list):
            raise ValidationError(_('Not a list.'))
        return list(
            sorted([x['identifier'] for x in data
                    if x.get('identifier')])) or missing

    def dump_prereservedoi(self, obj):
        """Dump pre-reserved DOI."""
        recid = obj.get('recid')
        if recid:
            prefix = None
            if not current_app:
                prefix = '10.5072'  # Test prefix
            return dict(
                recid=recid,
                doi=doi_generator(recid, prefix=prefix),
            )
        return missing

    def load_prereservedoi(self, obj):
        """Load pre-reserved DOI.

        The value is not important as we do not store it. Since the deposit
        and record id are now the same
        """
        return missing

    @pre_dump()
    def predump_related_identifiers(self, data):
        """Split related/alternate identifiers.

        This ensures that we can just use the base schemas definitions of
        related/alternate identifies.
        """
        relids = data.pop('related_identifiers', [])
        alids = data.pop('alternate_identifiers', [])

        for a in alids:
            a['relation'] = 'isAlternateIdentifier'

        if relids or alids:
            data['related_identifiers'] = relids + alids

        return data

    @pre_load()
    def preload_related_identifiers(self, data):
        """Split related/alternate identifiers.

        This ensures that we can just use the base schemas definitions of
        related/alternate identifies for loading.

        NOTE(review): this hook (and the following pre_load/post_load
        hooks) mutates ``data`` in place and returns None — relies on the
        marshmallow version keeping the mutated data when a processor
        returns None; confirm against the pinned version.
        """
        # Legacy API does not accept alternate_identifiers, so force delete it.
        data.pop('alternate_identifiers', None)

        for r in data.pop('related_identifiers', []):
            # Problem that API accepted one relation while documentation
            # presented a different relation.
            if r.get('relation') in [
                    'isAlternativeIdentifier', 'isAlternateIdentifier']:
                k = 'alternate_identifiers'
                r.pop('relation')
            else:
                k = 'related_identifiers'

            data.setdefault(k, [])
            data[k].append(r)

    @pre_load()
    def preload_resource_type(self, data):
        """Prepare data for easier deserialization."""
        # Drop the subtype that does not match the declared upload_type.
        if data.get('upload_type') != 'publication':
            data.pop('publication_type', None)
        if data.get('upload_type') != 'image':
            data.pop('image_type', None)

    @pre_load()
    def preload_license(self, data):
        """Default license."""
        acc = data.get('access_right', AccessRight.OPEN)
        if acc in [AccessRight.OPEN, AccessRight.EMBARGOED]:
            if 'license' not in data:
                if data.get('upload_type') == 'dataset':
                    data['license'] = 'CC0-1.0'
                else:
                    data['license'] = 'CC-BY-4.0'

    @post_load()
    def merge_keys(self, data):
        """Merge dot keys."""
        prefixes = [
            'resource_type',
            'journal',
            'meeting',
            'imprint',
            'part_of',
            'thesis',
        ]

        for p in prefixes:
            for k in list(data.keys()):
                if k.startswith('{0}.'.format(p)):
                    key, subkey = k.split('.')
                    if key not in data:
                        data[key] = dict()
                    data[key][subkey] = data.pop(k)

        # Pre-reserve DOI is implemented differently now.
        data.pop('prereserve_doi', None)

    @validates('communities')
    def validate_communities(self, values):
        """Validate communities."""
        for v in values:
            if not isinstance(v, six.string_types):
                raise ValidationError(
                    _('Invalid community identifier.'),
                    field_names=['communities'])

    @validates_schema
    def validate_data(self, obj):
        """Validate resource type."""
        type_ = obj.get('resource_type', {}).get('type')
        if type_ in ['publication', 'image']:
            type_dict = {
                'type': type_,
                'subtype': obj.get('resource_type', {}).get('subtype')
            }
            field_names = ['{0}_type'.format(type_)]
        else:
            type_dict = {'type': type_}
            field_names = ['upload_type']

        if ObjectType.get_by_dict(type_dict) is None:
            raise ValidationError(
                _('Invalid upload, publication or image type.'),
                field_names=field_names,
            )
class BadSchema(Schema):
    """Schema whose Method field references serializer/deserializer names
    that are not defined on the class."""

    # Same field as before, declared with both hooks as keyword arguments.
    uppername = fields.Method(
        serialize='uppercase_name',
        deserialize='lowercase_name',
    )
class MiniUserSchema(Schema):
    """Schema exposing a single computed field: the object upper-cased."""

    uppername = fields.Method(serialize='uppercase_name')

    def uppercase_name(self, obj):
        """Return *obj* converted to upper case via its ``upper`` method."""
        uppercased = obj.upper()
        return uppercased
class CommonRecordSchema(Schema, StrictKeysMixin):
    """Base record schema."""

    id = fields.Str(attribute='pid.pid_value', dump_only=True)
    schema = fields.Method('get_schema', dump_only=True)
    experiment = fields.Str(attribute='metadata._experiment', dump_only=True)
    status = fields.Str(attribute='metadata._deposit.status', dump_only=True)
    created_by = fields.Method('get_created_by', dump_only=True)
    is_owner = fields.Method('is_current_user_owner', dump_only=True)
    metadata = fields.Method('get_metadata', dump_only=True)
    links = fields.Raw(dump_only=True)
    files = fields.Method('get_files', dump_only=True)
    access = fields.Method('get_access', dump_only=True)
    created = fields.Str(dump_only=True)
    updated = fields.Str(dump_only=True)
    revision = fields.Integer(dump_only=True)
    labels = fields.Method('get_labels', dump_only=True)

    def is_current_user_owner(self, obj):
        """Tell whether the logged-in user created this deposit."""
        user_id = obj['metadata']['_deposit'].get('created_by')
        if user_id and current_user:
            return user_id == current_user.id
        return False

    def get_files(self, obj):
        """Return the record's attached files ('_files'), if any."""
        return obj['metadata'].get('_files', [])

    def get_schema(self, obj):
        """Resolve the record's '$schema' URL to its name and version."""
        schema = resolve_schema_by_url(obj['metadata']['$schema'])
        result = {'name': schema.name, 'version': schema.version}
        return result

    def get_metadata(self, obj):
        """Return the record metadata minus internal bookkeeping keys."""
        result = {
            k: v
            for k, v in obj.get('metadata', {}).items()
            # Internal/administrative keys are never exposed to clients.
            if k not in [
                'control_number', '$schema', '_deposit', '_experiment',
                '_access', '_files', '_fetched_from', '_user_edited'
            ]
        }
        return result

    def get_created_by(self, obj):
        """Return the e-mail of the user who created the deposit, if known."""
        user_id = obj['metadata']['_deposit'].get('created_by')
        if user_id:
            user = User.query.filter_by(id=user_id).one()
            return user.email
        return None

    def get_access(self, obj):
        """Return access object."""
        access = obj['metadata']['_access']
        # Replace raw user/role ids with human-readable identifiers, in place.
        for permission in access.values():
            if permission['users']:
                for index, user_id in enumerate(permission['users']):
                    permission['users'][index] = get_user_email_by_id(user_id)
            if permission['roles']:
                for index, role_id in enumerate(permission['roles']):
                    permission['roles'][index] = get_role_name_by_id(role_id)
        return access

    def get_labels(self, obj):
        """Get labels."""
        labels = set()
        for label in LABELS.values():
            condition = label.get('condition')
            # A label with no condition always applies.
            if not condition or condition(obj):
                labels = labels | _dot_access_helper(
                    obj, label.get('path'), label.get('formatter'))
        return list(labels)
class MySchema(Schema):
    """Schema whose Method field always fails with ``AttributeError``."""

    mfield = fields.Method(serialize="raise_error")

    def raise_error(self, obj):
        """Unconditionally raise ``AttributeError``; *obj* is ignored."""
        failure = AttributeError()
        raise failure
class MetadataSchemaV1(BaseSchema):
    """Schema for the record metadata."""

    class Meta:
        """Meta class to accept unknown fields."""

        unknown = INCLUDE

    # Administrative fields
    _access = Nested(AccessSchemaV1, required=True)
    _owners = fields.List(fields.Integer, validate=validate.Length(min=1),
                          required=True)
    _created_by = fields.Integer(required=True)
    _default_preview = SanitizedUnicode()
    _files = fields.List(Nested(FilesSchemaV1, dump_only=True))
    _internal_notes = fields.List(Nested(InternalNoteSchemaV1))
    _embargo_date = DateString(data_key="embargo_date",
                               attribute="embargo_date")
    _communities = GenMethod('dump_communities')
    _contact = SanitizedUnicode(data_key="contact", attribute="contact")

    # Metadata fields
    access_right = SanitizedUnicode(required=True)
    identifiers = Identifiers()
    creators = fields.List(Nested(CreatorSchemaV1), required=True)
    titles = fields.List(Nested(TitleSchemaV1), required=True)
    resource_type = Nested(ResourceTypeSchemaV1, required=True)
    recid = SanitizedUnicode()
    publication_date = EDTFLevel0DateString(required=True)
    subjects = fields.List(Nested(SubjectSchemaV1))
    contributors = fields.List(Nested(ContributorSchemaV1))
    dates = fields.List(Nested(DateSchemaV1))
    language = SanitizedUnicode(validate=validate_iso639_3)
    related_identifiers = fields.List(Nested(RelatedIdentifierSchemaV1))
    version = SanitizedUnicode()
    licenses = fields.List(Nested(LicenseSchemaV1))
    descriptions = fields.List(Nested(DescriptionSchemaV1))
    locations = fields.List(Nested(LocationSchemaV1))
    references = fields.List(Nested(ReferenceSchemaV1))
    extensions = fields.Method('dump_extensions', 'load_extensions')

    def dump_extensions(self, obj):
        """Dumps the extensions value.

        :params obj: invenio_records_files.api.Record instance
        """
        # The extension schema is built dynamically from the app's
        # configured metadata extensions.
        current_app_metadata_extensions = (
            current_app.extensions['invenio-rdm-records'].metadata_extensions)
        ExtensionSchema = current_app_metadata_extensions.to_schema()
        return ExtensionSchema().dump(obj.get('extensions', {}))

    def load_extensions(self, value):
        """Loads the 'extensions' field.

        :params value: content of the input's 'extensions' field
        """
        current_app_metadata_extensions = (
            current_app.extensions['invenio-rdm-records'].metadata_extensions)
        ExtensionSchema = current_app_metadata_extensions.to_schema()
        return ExtensionSchema().load(value)

    def dump_communities(self, obj):
        """Dumps communities related to the record."""
        # NOTE: If the field is already there, it's coming from ES
        if '_communities' in obj:
            return CommunityStatusV1().dump(obj['_communities'])

        record = self.context.get('record')
        if record:
            _record = Record(record, model=record.model)
            return CommunityStatusV1().dump(
                RecordCommunitiesCollection(_record).as_dict())

    @validates('_embargo_date')
    def validate_embargo_date(self, value):
        """Validate that embargo date is in the future."""
        if arrow.get(value).date() <= arrow.utcnow().date():
            raise ValidationError(_('Embargo date must be in the future.'),
                                  field_names=['embargo_date'])

    @validates('access_right')
    def validate_access_right(self, value):
        """Validate that access right is one of the allowed ones."""
        access_right_key = {'access_right': value}
        validate_entry('access_right', access_right_key)

    @post_load
    def post_load_publication_date(self, obj, **kwargs):
        """Add '_publication_date_search' field."""
        prepare_publication_date(obj)
        return obj
def test_method_with_no_serialize_is_missing(self): m = fields.Method() m.parent = Schema() assert m.serialize("", "", "") is missing_
class BadSerializer(Schema):
    """Schema whose Method field names a class attribute that is a plain
    string rather than a callable."""

    # 'foo' is deliberately not a method.
    bad_field = fields.Method('foo')

    foo = 'not callable'
class BibTexCommonSchema(BaseSchema):
    """Serialize an INSPIRE literature record into BibTeX fields.

    Every BibTeX field is computed by a ``get_*`` method; ``filter_data``
    injects the chosen BibTeX entry type (``doc_type``) before dumping.
    """

    address = fields.Method("get_address")
    archivePrefix = fields.Method("get_archive_prefix")
    author = fields.Method("get_author")
    authors_with_role_author = fields.Method("get_authors_with_role_author")
    authors_with_role_editor = fields.Method("get_authors_with_role_editor")
    booktitle = fields.Method("get_book_title")
    collaboration = fields.Method("get_collaboration")
    doc_type = fields.Raw()
    doi = fields.Method("get_doi")
    edition = fields.Method("get_edition")
    eprint = fields.Method("get_eprint")
    isbn = fields.Method("get_isbn")
    journal = fields.Method("get_journal")
    month = fields.Method("get_month")
    note = fields.Method("get_note")
    number = fields.Method("get_number")
    pages = fields.Method("get_pages")
    primaryClass = fields.Method("get_primary_class")
    publisher = fields.Method("get_publisher")
    reportNumber = fields.Method("get_report_number")
    school = fields.Method("get_school")
    series = fields.Method("get_series")
    texkey = fields.Method("get_texkey")
    title = fields.Method("get_title")
    type_ = fields.Method("get_type", attribute="type", dump_to="type")
    url = fields.Method("get_url")
    volume = fields.Method("get_volume")
    year = fields.Method("get_year")

    # Maps an INSPIRE document type to a BibTeX entry type; callables
    # resolve ambiguous cases from the record's content.
    document_type_map = {
        "article": "article",
        "book": "book",
        "book chapter": lambda data: "article"
        if get_value(data, "publication_info.journal_title")
        else "inbook",
        "proceedings": "proceedings",
        "report": "article",
        "note": "article",
        "conference paper": lambda data: "article"
        if get_value(data, "publication_info.journal_title")
        else "inproceedings",
        "thesis": lambda data: "article"
        if get_value(data, "publication_info.journal_title")
        else (
            "phdthesis"
            if get_value(data, "thesis_info.degree_type")
            in ("phd", "habilitation")
            else "mastersthesis"
        ),
    }

    @staticmethod
    def get_date(data, doc_type):
        """Pick the most appropriate date for *doc_type* as a PartialDate."""
        publication_year = BibTexCommonSchema.get_best_publication_info(
            data).get("year")
        thesis_date = get_value(data, "thesis_info.date")
        imprint_date = get_value(data, "imprints.date[0]")
        earliest_date = data.earliest_date
        # Theses prefer the thesis date, books the imprint date; everything
        # else falls back to publication year, then the earliest known date.
        date_map = {
            "mastersthesis": thesis_date,
            "phdthesis": thesis_date,
            "book": imprint_date,
            "inbook": imprint_date,
        }
        date_choice = (date_map.get(doc_type) or publication_year
                       or earliest_date)
        if date_choice:
            return PartialDate.loads(str(date_choice))

    @staticmethod
    def get_authors_with_role(authors, role):
        """Return LaTeX-encoded full names of *authors* having *role*."""
        return [
            latex_encode(author["full_name"])
            for author in authors
            if role in author.get("inspire_roles", ["author"])
        ]

    @staticmethod
    def get_document_type(data, doc_type):
        """Map an INSPIRE *doc_type* to a BibTeX entry type ('misc' if unknown)."""
        if doc_type in BibTexCommonSchema.document_type_map:
            doc_type_value = BibTexCommonSchema.document_type_map[doc_type]
            return (doc_type_value(data) if callable(doc_type_value)
                    else doc_type_value)
        return "misc"

    @staticmethod
    def get_bibtex_document_type(data):
        """Choose one BibTeX entry type for the record ('article' wins)."""
        bibtex_doc_types = [
            BibTexCommonSchema.get_document_type(data, doc_type)
            for doc_type in data["document_type"]
        ] + ["misc"]
        chosen_type = ("article" if "article" in bibtex_doc_types
                       else bibtex_doc_types[0])
        return chosen_type

    @staticmethod
    def get_best_publication_info(data):
        """Return the most complete 'publication' entry, or {} if none."""
        publication_info = get_value(data, "publication_info", [])
        only_publications = [
            entry
            for entry in publication_info
            if entry.get("material", "publication") == "publication"
        ]
        if not only_publications:
            return {}
        # The entry with the most keys is considered the most complete.
        return sorted(only_publications, key=len, reverse=True)[0]

    def get_authors_with_role_author(self, data):
        """Return names of authors with the 'author' role."""
        return self.get_authors_with_role(data.get("authors", []), "author")

    def get_authors_with_role_editor(self, data):
        """Return editors, falling back to the parent book's editors."""
        editors = self.get_authors_with_role(data.get("authors", []), "editor")
        if not editors and data.get("doc_type") in [
                "inbook",
                "inproceedings",
                "article",
        ]:
            editors = self.get_book_editors(data)
        return editors

    def get_eprint(self, data):
        """Return the first arXiv eprint identifier, if any."""
        return get_value(data, "arxiv_eprints.value[0]", default=None)

    def get_archive_prefix(self, data):
        """Return 'arXiv' when the record has an arXiv eprint."""
        eprint = get_value(data, "arxiv_eprints.value[0]", default=None)
        if eprint:
            return "arXiv"
        return None

    def get_collaboration(self, data):
        """Return a comma-joined, LaTeX-encoded collaboration list."""
        collaboration = ", ".join(
            get_value(data, "collaborations.value", default=[]))
        return latex_encode(collaboration)

    def get_doi(self, data):
        """Return the first DOI, if any."""
        return get_value(data, "dois.value[0]")

    def get_month(self, data):
        """Return the month of the record's best date, if known."""
        doc_type = data.get("doc_type")
        date = self.get_date(data, doc_type)
        if date:
            return date.month

    def get_year(self, data):
        """Return the year of the record's best date, if known."""
        doc_type = data.get("doc_type")
        date = self.get_date(data, doc_type)
        if date:
            return date.year

    def get_texkey(self, data):
        """Return the record's texkey, defaulting to its control number."""
        control_number = str(data.get("control_number"))
        return get_value(data, "texkeys[0]", default=control_number)

    def get_note(self, data):
        """Build a '[Erratum: ..., Addendum: ...]' note, or None."""
        notices = ("erratum", "addendum")
        entries = [
            entry
            for entry in get_value(data, "publication_info", [])
            if entry.get("material") in notices
        ]

        if not entries:
            return None

        note_strings = [
            text_type("{field}: {journal} {volume}, {pages} {year}").format(
                field=entry["material"].title(),
                journal=entry.get("journal_title"),
                volume=entry.get("journal_volume"),
                pages=LiteratureReader.get_page_artid_for_publication_info(
                    entry, "--"),
                year="({})".format(entry["year"]) if "year" in entry else "",
            ).strip()
            for entry in entries
        ]

        note_string = "[" + ", ".join(note_strings) + "]"
        # Collapse artifacts left by missing parts of the template.
        note_string = re.sub(" +", " ", note_string)
        return latex_encode(re.sub(",,", ",", note_string))

    def get_primary_class(self, data):
        """Return the primary arXiv category for post-2007 identifiers."""
        eprint = get_value(data, "arxiv_eprints.value[0]")
        if eprint and is_arxiv_post_2007(eprint):
            return get_value(data, "arxiv_eprints[0].categories[0]")

    def get_title(self, data):
        """Return '{Title}: {Subtitle}' with LaTeX-safe braces, or None."""
        title_dict = get_value(data, "titles[0]")
        if not title_dict:
            return None
        title_parts = [title_dict["title"]]
        if "subtitle" in title_dict:
            title_parts.append(title_dict["subtitle"])
        return ": ".join(
            f"{{{latex_encode(part, contains_math=True)}}}"
            for part in title_parts)

    def get_url(self, data):
        """Return the first URL, if any."""
        return get_value(data, "urls.value[0]")

    def get_author(self, data):
        """Return corporate authors joined with ' and ', each braced."""
        return " and ".join(
            f"{{{latex_encode(author)}}}"
            for author in data.get("corporate_author", []))

    def get_number(self, data):
        """Return the journal issue of the best publication entry."""
        return BibTexCommonSchema.get_best_publication_info(data).get(
            "journal_issue")

    def get_address(self, data):
        """Return the conference 'city, country', or the imprint place."""
        conference = ConferenceReader(data)
        pubinfo_city = latex_encode(conference.city)
        pubinfo_country_code = latex_encode(conference.country)
        if pubinfo_city and pubinfo_country_code:
            return f"{pubinfo_city}, {pubinfo_country_code}"
        return latex_encode(get_value(data, "imprints[0].place"))

    def get_type(self, data):
        """Return a thesis-type annotation for non-standard master's theses."""
        doc_type = data.get("doc_type")
        degree_type = get_value(data, "thesis_info.degree_type", "other")
        if doc_type == "mastersthesis" and degree_type not in ("master",
                                                               "diploma"):
            return "{} thesis".format(degree_type.title())

    def get_report_number(self, data):
        """Return non-hidden report numbers, comma-joined."""
        report_number = ", ".join(
            report["value"]
            for report in data.get("report_numbers", [])
            if not report.get("hidden", False))
        return latex_encode(report_number)

    def get_school(self, data):
        """Return the thesis institutions, comma-joined."""
        schools = [
            school["name"]
            for school in get_value(data, "thesis_info.institutions", [])
        ]
        if schools:
            return latex_encode(", ".join(schools))

    def get_publisher(self, data):
        """Return the first imprint publisher."""
        return latex_encode(get_value(data, "imprints.publisher[0]"))

    def get_series(self, data):
        """Return the first book-series title."""
        return latex_encode(get_value(data, "book_series.title[0]"))

    def get_book_title(self, data):
        """Return the title of the parent record."""
        parent_record = get_parent_record(data)
        parent_title = self.get_title(parent_record)
        return parent_title

    def get_book_editors(self, data):
        """Return the editors of the parent record."""
        parent_record = get_parent_record(data)
        parent_editors = self.get_authors_with_role_editor(parent_record)
        return parent_editors

    def get_volume(self, data):
        """Return the journal volume, falling back to the series volume."""
        publication_volume = BibTexCommonSchema.get_best_publication_info(
            data).get("journal_volume")
        bookseries_volume = get_value(data, "book_series.volume[0]")
        return publication_volume or bookseries_volume

    def get_pages(self, data):
        """Return the page range/artid of the best publication entry."""
        return LiteratureReader.get_page_artid_for_publication_info(
            BibTexCommonSchema.get_best_publication_info(data), "--")

    def get_edition(self, data):
        """Return the first edition, if any."""
        return get_value(data, "editions[0]")

    def get_journal(self, data):
        """Return the journal title with abbreviation dots spaced out.

        Fix: guard against a missing ``journal_title`` — the original called
        ``.replace`` on ``None`` and crashed for records without a journal
        publication entry.
        """
        journal_title = BibTexCommonSchema.get_best_publication_info(
            data).get("journal_title")
        if journal_title is None:
            return None
        return latex_encode(journal_title.replace(".", ". ").rstrip(" "))

    def get_isbn(self, data):
        """Return the record's ISBNs, hyphenated when possible."""
        def hyphenate_if_possible(no_hyphens):
            # Keep the raw value when it cannot be normalized.
            try:
                return normalize_isbn(no_hyphens)
            except ISBNError:
                return no_hyphens

        isbns = get_value(data, "isbns.value", [])
        if isbns:
            return ", ".join(hyphenate_if_possible(isbn) for isbn in isbns)

    @pre_dump
    def filter_data(self, data):
        """Compute the BibTeX entry type before dumping."""
        processed_data = data.copy()
        processed_data["doc_type"] = (
            BibTexCommonSchema.get_bibtex_document_type(processed_data))
        return processed_data
class UpdateItemSchema(Schema):
    """Schema validating the payload of an item-update request."""

    id = fields.Int()
    item_title = fields.String(required=True, validate=cannot_be_blank)
    price = fields.String(required=True, validate=cannot_be_blank)
    description = fields.String(required=True, validate=cannot_be_blank)
    status_item = fields.Method("bool_to_status")
    created_at = fields.DateTime()
    updated_at = fields.DateTime()
    deleted_at = fields.DateTime()

    def bool_to_status(self, obj):
        """Map the item's ``status_item`` flag to 'ACTIVE'/'INACTIVE'."""
        # Assumes status_item is a boolean column -- TODO confirm model type.
        return "ACTIVE" if obj.status_item else "INACTIVE"

    @validates('item_title')
    def validate_item_title(self, item_title):
        """Validate the title: 2-40 chars of letters, digits, '_', ' ', '-'."""
        pattern = r"^[a-z-A-Z_0-9 ]+$"
        if len(item_title) < 2:
            raise ValidationError(
                'Invalid {}. min is 2 character'.format(item_title))
        if len(item_title) > 40:
            raise ValidationError(
                'Invalid {}, max is 40 character'.format(item_title))
        # NOTE(review): the message undersells the pattern, which also
        # allows digits, spaces, '_' and '-'.
        if re.match(pattern, item_title) is None:
            raise ValidationError(
                'Invalid {}. only alphabet is allowed'.format(item_title))

    @validates('description')
    def validate_description(self, description):
        """Validate the description length (2-1000 characters)."""
        # NOTE(review): r"." only rejects a description starting with a
        # newline; kept for backward compatibility.
        pattern = r"."
        if len(description) < 2:
            raise ValidationError('Invalid description, min is 2 characters')
        if len(description) > 1000:
            # Fix: the original message said "min is 1000" for the upper bound.
            raise ValidationError(
                'Invalid description, max is 1000 characters')
        if re.match(pattern, description) is None:
            raise ValidationError(' see the rule of description')

    @validates('price')
    def validate_price(self, price):
        """Validate the price: a 2-40 character string of digits only."""
        pattern = r"^[0-9]+$"
        if len(price) < 2:
            raise ValidationError(
                'Invalid {}. min is 2 character'.format(price))
        if len(price) > 40:
            raise ValidationError(
                'Invalid {}, max is 40 character'.format(price))
        if re.match(pattern, price) is None:
            # Fix: the original message said "only alphabet" for a
            # digits-only pattern.
            raise ValidationError(
                'Invalid {}. only numbers are allowed'.format(price))
class CaseSchema(ma.SQLAlchemyAutoSchema):
    """Serialize a Case row, flattening its location columns into nested
    location dicts and expanding related users/activities/notes/files."""

    class Meta:
        model = models.Case
        additional = ['surveys', 'documents', 'history']
        # Raw location columns are excluded; they are exposed through the
        # computed created_location/updated_location fields instead.
        exclude = ['created_location_coordinates',
                   'created_location_altitude',
                   'created_location_altitude_accuracy',
                   'created_location_dt',
                   'created_location_heading',
                   'created_location_position_accuracy',
                   'created_location_speed',
                   'updated_location_coordinates',
                   'updated_location_altitude',
                   'updated_location_altitude_accuracy',
                   'updated_location_dt',
                   'updated_location_heading',
                   'updated_location_position_accuracy',
                   'updated_location_speed',
                   'assigned_at']

    assigned_to = fields.Method('get_assigned_to')
    case_definition = ma.Nested(CaseDefinitionSchema,
                                only=['id', 'key', 'name'])
    created_by = ma.Nested(UserSchema, only=['id', 'email', 'username',
                                             'name', 'color'])
    updated_by = ma.Nested(UserSchema, only=['id', 'email', 'username',
                                             'name', 'color'])
    activities = ma.Nested(ActivitySchema, many=True,
                           only=['id', 'name', 'is_complete', 'completed_at',
                                 'completed_by', 'created_at', 'updated_at'])
    status = ma.Nested(CaseStatusSchema)
    notes = ma.Nested(NoteSchema, many=True,
                      only=['id', 'note', 'created_at', 'updated_at',
                            'created_by', 'updated_by'])
    files = ma.Nested(
        UploadedFileSchema, many=True,
        only=['id', 'original_filename', 'remote_filename', 'created_at',
              'created_by', 'url']
    )
    created_location = fields.Method('get_created_location')
    updated_location = fields.Method('get_updated_location')

    def _location_payload(self, obj, prefix):
        """Build the serialized location dict for the given column prefix
        ('created' or 'updated').

        Factors out the logic previously duplicated between
        get_created_location and get_updated_location.
        """
        coordinates = getattr(obj, prefix + '_location_coordinates')
        if coordinates is not None:
            # Coordinates are a PostGIS geometry; extract lat/lon via the DB.
            latitude = db.session.scalar(coordinates.ST_Y())
            longitude = db.session.scalar(coordinates.ST_X())
        else:
            latitude = None
            longitude = None

        recorded_dt = getattr(obj, prefix + '_location_dt')
        recorded_dt_value = (recorded_dt.isoformat()
                             if recorded_dt is not None else None)

        return {
            "latitude": latitude,
            "longitude": longitude,
            "position_accuracy": getattr(
                obj, prefix + '_location_position_accuracy'),
            "altitude": getattr(obj, prefix + '_location_altitude'),
            "altitude_accuracy": getattr(
                obj, prefix + '_location_altitude_accuracy'),
            "heading": getattr(obj, prefix + '_location_heading'),
            "speed": getattr(obj, prefix + '_location_speed'),
            "location_recorded_dt": recorded_dt_value
        }

    def get_created_location(self, obj):
        """Return the location recorded when the case was created."""
        return self._location_payload(obj, 'created')

    def get_updated_location(self, obj):
        """Return the location recorded when the case was last updated."""
        return self._location_payload(obj, 'updated')

    def get_assigned_to(self, case: models.Case):
        """Return the assignee's summary dict, or None if unassigned."""
        if case.assigned_to_id:
            if case.assigned_at is not None:
                assigned_at = case.assigned_at.isoformat()
            else:
                assigned_at = None
            return {
                'id': case.assigned_to.id,
                'email': case.assigned_to.email,
                'username': case.assigned_to.username,
                'name': case.assigned_to.name,
                'color': case.assigned_to.color,
                'assigned_at': assigned_at
            }
class WalletSchema(Schema):
    """Serialize a ``models.Creditor`` instance as the API's Wallet object
    (the entry point that links to the creditor's other resources)."""

    uri = fields.String(
        required=True,
        dump_only=True,
        format='uri-reference',
        description=URI_DESCRIPTION,
        example='/creditors/2/wallet',
    )
    type = fields.Function(
        lambda obj: type_registry.wallet,
        required=True,
        type='string',
        description='The type of this object.',
        example='Wallet',
    )
    creditor = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        description="The URI of the `Creditor`.",
        example={'uri': '/creditors/2/'},
    )
    accounts_list = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        data_key='accountsList',
        description=
        "The URI of creditor's `AccountsList`. That is: an URI of a `PaginatedList` of "
        "`ObjectReference`s to all `Account`s belonging to the creditor. The paginated "
        "list will not be sorted in any particular order.",
        example={'uri': '/creditors/2/accounts-list'},
    )
    log = fields.Nested(
        PaginatedStreamSchema,
        required=True,
        dump_only=True,
        description=
        "A `PaginatedStream` of creditor's `LogEntry`s. The paginated stream will be "
        "sorted in chronological order (smaller entry IDs go first). The main "
        "purpose of the log stream is to allow the clients of the API to reliably "
        "and efficiently invalidate their caches, simply by following the \"log\".",
        example={
            'first': '/creditors/2/log',
            'forthcoming': '/creditors/2/log?prev=12345',
            'itemsType': 'LogEntry',
            'type': 'PaginatedStream',
        },
    )
    log_retention_days = fields.Method(
        'get_log_retention_days',
        required=True,
        dump_only=True,
        type='integer',
        format="int32",
        data_key='logRetentionDays',
        description=
        "The entries in the creditor's log stream will not be deleted for at least this "
        "number of days. This will always be a positive number.",
        example=30,
    )
    transfers_list = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        data_key='transfersList',
        description=
        "The URI of creditor's `TransfersList`. That is: an URI of a `PaginatedList` of "
        "`ObjectReference`s to all `Transfer`s initiated by the creditor, which have not "
        "been deleted yet. The paginated list will not be sorted in any particular order.",
        example={'uri': '/creditors/2/transfers-list'},
    )
    create_account = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        data_key='createAccount',
        description=
        'A URI to which a `DebtorIdentity` object can be POST-ed to create a new `Account`.',
        example={'uri': '/creditors/2/accounts/'},
    )
    create_transfer = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        data_key='createTransfer',
        description=
        'A URI to which a `TransferCreationRequest` can be POST-ed to '
        'create a new `Transfer`.',
        example={'uri': '/creditors/2/transfers/'},
    )
    account_lookup = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        data_key='accountLookup',
        description=
        "A URI to which the recipient account's `AccountIdentity` can be POST-ed, "
        "trying to find the identify of the account's debtor. If the debtor has "
        "been identified successfully, the response will contain the debtor's "
        "`DebtorIdentity`. Otherwise, the response code will be 422.",
        example={'uri': '/creditors/2/account-lookup'},
    )
    debtor_lookup = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        data_key='debtorLookup',
        description=
        "A URI to which a `DebtorIdentity` object can be POST-ed, trying to find an "
        "existing account with this debtor. If an existing account is found, the "
        "response will redirect to the `Account` (response code 303). Otherwise, "
        "the response will be empty (response code 204).",
        example={'uri': '/creditors/2/debtor-lookup'},
    )
    pin_info_reference = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        data_key='pinInfo',
        description="The URI of creditor's `PinInfo`.",
        example={'uri': '/creditors/2/pin'},
    )
    require_pin = fields.Boolean(
        required=True,
        dump_only=True,
        data_key='requirePin',
        description=
        "Whether the PIN is required for potentially dangerous operations."
        "\n\n"
        "**Note:** The PIN will never be required when in \"PIN reset\" mode.",
        example=True,
    )

    def get_log_retention_days(self, obj):
        """Return the configured log retention period for this creditor."""
        # The calculator is injected via the schema context by the caller.
        calc_log_retention_days = self.context['calc_log_retention_days']
        days = calc_log_retention_days(obj.creditor_id)
        # Contract: retention is always positive (see the field description).
        assert days > 0
        return days

    @pre_dump
    def process_creditor_instance(self, obj, many):
        # Work on a shallow copy so the ORM instance is not mutated; all
        # URIs are derived from the creditor id via the injected paths helper.
        assert isinstance(obj, models.Creditor)
        paths = self.context['paths']
        calc_require_pin = self.context['calc_require_pin']
        obj = copy(obj)
        obj.uri = paths.wallet(creditorId=obj.creditor_id)
        obj.creditor = {'uri': paths.creditor(creditorId=obj.creditor_id)}
        obj.accounts_list = {
            'uri': paths.accounts_list(creditorId=obj.creditor_id)
        }
        obj.transfers_list = {
            'uri': paths.transfers_list(creditorId=obj.creditor_id)
        }
        obj.account_lookup = {
            'uri': paths.account_lookup(creditorId=obj.creditor_id)
        }
        obj.debtor_lookup = {
            'uri': paths.debtor_lookup(creditorId=obj.creditor_id)
        }
        obj.create_account = {
            'uri': paths.accounts(creditorId=obj.creditor_id)
        }
        obj.create_transfer = {
            'uri': paths.transfers(creditorId=obj.creditor_id)
        }
        obj.pin_info_reference = {
            'uri': paths.pin_info(creditorId=obj.creditor_id)
        }
        obj.require_pin = calc_require_pin(obj.pin_info)

        log_path = paths.log_entries(creditorId=obj.creditor_id)
        obj.log = {
            'items_type': type_registry.log_entry,
            'first': log_path,
            'forthcoming': f'{log_path}?prev={obj.last_log_entry_id}',
        }

        return obj
class RecordSchemaOpenAIREJSON(Schema):
    """Schema for records in OpenAIRE-JSON.

    OpenAIRE Schema: https://www.openaire.eu/schema/1.0/oaf-result-1.0.xsd
    OpenAIRE Vocabularies: http://api.openaire.eu/vocabularies
    """

    originalId = fields.Method('get_original_id', required=True)
    title = fields.Str(attribute='metadata.title', required=True)
    description = fields.Str(attribute='metadata.description')
    url = fields.Method('get_url', required=True)

    authors = fields.List(fields.Str(attribute='name'),
                          attribute='metadata.creators')

    type = fields.Method('get_type')
    resourceType = fields.Method('get_resource_type', required=True)
    language = fields.Str(attribute='metadata.language')

    licenseCode = fields.Method('get_license_code', required=True)
    embargoEndDate = DateString(attribute='metadata.embargo_date')

    publisher = fields.Method('get_publisher')
    collectedFromId = fields.Method('get_datasource_id', required=True)
    hostedById = fields.Method('get_datasource_id')

    linksToProjects = fields.Method('get_links_to_projects')
    pids = fields.Method('get_pids')

    def _openaire_type(self, obj):
        # Resolve the record's resource type to its OpenAIRE mapping.
        return ObjectType.get_by_dict(
            obj.get('metadata', {}).get('resource_type')
        ).get('openaire')

    def get_original_id(self, obj):
        """Get Original Id."""
        oatype = self._openaire_type(obj)
        if oatype:
            return openaire_original_id(
                obj.get('metadata', {}),
                oatype['type']
            )[1]
        return missing

    def get_type(self, obj):
        """Get record type."""
        oatype = self._openaire_type(obj)
        if oatype:
            return oatype['type']
        return missing

    def get_resource_type(self, obj):
        """Get resource type."""
        oatype = self._openaire_type(obj)
        if oatype:
            return oatype['resourceType']
        return missing

    def get_datasource_id(self, obj):
        """Get OpenAIRE datasouce identifier."""
        return openaire_datasource_id(obj.get('metadata')) or missing

    # Mapped from: http://api.openaire.eu/vocabularies/dnet:access_modes
    LICENSE_MAPPING = {
        'open': 'OPEN',
        'embargoed': 'EMBARGO',
        'restricted': 'RESTRICTED',
        'closed': 'CLOSED',
    }

    def get_license_code(self, obj):
        """Get license code."""
        metadata = obj.get('metadata')
        return self.LICENSE_MAPPING.get(
            metadata.get('access_right'), 'UNKNOWN')

    def get_links_to_projects(self, obj):
        """Get project/grant links."""
        metadata = obj.get('metadata')
        grants = metadata.get('grants', [])
        links = []
        for grant in grants:
            eurepo = grant.get('identifiers', {}).get('eurepo', '')
            if eurepo:
                # Slashes in titles would break the eurepo link format.
                links.append('{eurepo}/{title}/{acronym}'.format(
                    eurepo=eurepo,
                    title=grant.get('title', '').replace('/', '%2F'),
                    acronym=grant.get('acronym', '')))
        return links or missing

    def get_pids(self, obj):
        """Get record PIDs."""
        metadata = obj.get('metadata')
        pids = [{'type': 'oai', 'value': metadata['_oai']['id']}]
        if 'doi' in metadata:
            pids.append({'type': 'doi', 'value': metadata['doi']})
        return pids

    def get_url(self, obj):
        """Get record URL."""
        return current_app.config['ZENODO_RECORDS_UI_LINKS_FORMAT'].format(
            recid=obj['metadata']['recid'])

    def get_publisher(self, obj):
        """Get publisher."""
        m = obj['metadata']
        imprint_publisher = m.get('imprint', {}).get('publisher')
        if imprint_publisher:
            return imprint_publisher
        part_publisher = m.get('part_of', {}).get('publisher')
        if part_publisher:
            return part_publisher
        # Records minted under the Zenodo DOI prefix default to 'Zenodo'.
        if m.get('doi', '').startswith('10.5281/'):
            return 'Zenodo'
        return missing
class LogEntrySchema(Schema):
    """Serialize a ``models.LogEntry`` instance as an API LogEntry object."""

    type = fields.Function(
        lambda obj: type_registry.log_entry,
        required=True,
        type='string',
        description='The type of this object.',
        example='LogEntry',
    )
    entry_id = fields.Integer(
        required=True,
        dump_only=True,
        format='int64',
        data_key='entryId',
        description=
        'The ID of the log entry. This will always be a positive number. The first '
        'log entry has an ID of `1`, and the ID of each subsequent log entry will '
        'be equal to the ID of the previous log entry plus one.',
        example=12345,
    )
    added_at = fields.DateTime(
        required=True,
        dump_only=True,
        data_key='addedAt',
        description='The moment at which the entry was added to the log.',
    )
    object_type = fields.Method(
        'get_object_type',
        required=True,
        dump_only=True,
        type='string',
        data_key='objectType',
        description=
        'The type of the object that has been created, updated, or deleted.',
        example='Account',
    )
    object = fields.Nested(
        ObjectReferenceSchema,
        required=True,
        dump_only=True,
        description=
        'The URI of the object that has been created, updated, or deleted.',
        example={'uri': '/creditors/2/accounts/1/'},
    )
    is_deleted = fields.Function(
        lambda obj: bool(obj.is_deleted),
        required=True,
        dump_only=True,
        type='boolean',
        data_key='deleted',
        description='Whether the object has been deleted.',
        example=False,
    )
    optional_object_update_id = fields.Integer(
        dump_only=True,
        data_key='objectUpdateId',
        format='int64',
        description=
        'A positive number which gets incremented after each change in the '
        'object. When this field is not present, this means that the changed object '
        'does not have an update ID (the object is immutable, or has been deleted, '
        'for example).',
        example=10,
    )
    optional_data = fields.Dict(
        dump_only=True,
        data_key='data',
        description=
        'Optional information about the new state of the created/updated object. When '
        'present, this information can be used to avoid making a network request to '
        'obtain the new state. What properties the "data" object will have, depends '
        'on the value of the `objectType` field:'
        '\n\n'
        '### When the object type is "AccountLedger"\n'
        '`principal` and `nextEntryId` properties will be present.'
        '\n\n'
        '### When the object type is "Transfer"\n'
        'If the transfer is finalized, `finalizedAt` and (only when there is an '
        'error) `errorCode` properties will be present. If the transfer is not '
        'finalized, the "data" object will not be present.'
        '\n\n'
        '**Note:** This field will never be present when the object has been deleted.',
    )

    @pre_dump
    def process_log_entry_instance(self, obj, many):
        # Work on a shallow copy so the ORM instance is not mutated; the
        # "optional_*" attributes are only set when they should be dumped.
        assert isinstance(obj, models.LogEntry)
        obj = copy(obj)
        obj.object = {'uri': obj.get_object_uri(self.context['paths'])}

        if obj.object_update_id is not None:
            obj.optional_object_update_id = obj.object_update_id

        # Deleted objects never carry a "data" payload (see field docs).
        if not obj.is_deleted:
            data = obj.get_data_dict()
            if data is not None:
                obj.optional_data = data

        return obj

    def get_object_type(self, obj):
        """Return the API type name of the changed object."""
        return obj.get_object_type(self.context['types'])
class CredentialSchema(AutoSchema):
    """Schema for credentials attached to either a host or a service.

    Exactly one of ``host_id``/``service_id`` ends up set on load; the
    ``parent``/``parent_type`` pair in the payload selects which one
    (resolved in :meth:`set_parent`).
    """

    _id = fields.Integer(dump_only=True, attribute='id')
    _rev = fields.String(default='', dump_only=True)
    owned = fields.Boolean(default=False, dump_only=True)
    owner = fields.String(dump_only=True, attribute='creator.username',
                          default='')
    username = fields.String(default='', required=True,
                             validate=validate.Length(
                                 min=1,
                                 error="Username must be defined"))
    password = fields.String(default='')
    description = fields.String(default='')
    couchdbid = fields.String(default='')  # backwards compatibility
    parent_type = MutableField(fields.Method('get_parent_type'),
                               fields.String(),
                               required=True)
    parent = MutableField(fields.Method('get_parent'),
                          fields.Integer(),
                          required=True)
    host_ip = fields.String(dump_only=True, attribute="host.ip",
                            default=None)
    service_name = fields.String(dump_only=True, attribute="service.name",
                                 default=None)
    target = fields.Method('get_target', dump_only=True)

    # for filtering
    host_id = fields.Integer(load_only=True)
    service_id = fields.Integer(load_only=True)

    metadata = SelfNestedField(MetadataSchema())

    def get_parent(self, obj):
        """Serialize ``parent`` as the id of whichever FK is set.

        set_parent() guarantees that exactly one of host_id/service_id
        is non-null, so the ``or`` never mixes the two.
        """
        return obj.host_id or obj.service_id

    def get_parent_type(self, obj):
        """Serialize ``parent_type`` as 'Host' or 'Service'."""
        assert obj.host_id is not None or obj.service_id is not None
        return 'Service' if obj.service_id is not None else 'Host'

    def get_target(self, obj):
        """Serialize ``target``: the host IP, or ``ip/service_name``."""
        if obj.host is not None:
            return obj.host.ip
        else:
            return obj.service.host.ip + '/' + obj.service.name

    class Meta:
        model = Credential
        # Fix: the original tuple listed 'parent' twice; the duplicate
        # entry was redundant and has been removed.
        fields = ('id', '_id', "_rev", 'parent', 'username', 'description',
                  'name', 'password', 'owner', 'owned', 'couchdbid',
                  'parent_type', 'metadata', 'host_ip', 'service_name',
                  'target')

    @post_load
    def set_parent(self, data):
        """Resolve ``parent``/``parent_type`` into a concrete FK.

        Pops both virtual fields and replaces them with ``host_id`` or
        ``service_id`` (the other one is explicitly nulled).

        Raises:
            ValidationError: if ``parent_type`` is neither 'Host' nor
                'Service'.
            InvalidUsage: if no matching object exists in the workspace
                named by ``self.context['workspace_name']``.
        """
        parent_type = data.pop('parent_type', None)
        parent_id = data.pop('parent', None)
        if parent_type == 'Host':
            parent_class = Host
            parent_field = 'host_id'
            not_parent_field = 'service_id'
        elif parent_type == 'Service':
            parent_class = Service
            parent_field = 'service_id'
            not_parent_field = 'host_id'
        else:
            raise ValidationError(
                'Unknown parent type: {}'.format(parent_type))
        try:
            # Scope the lookup to the current workspace so a credential
            # cannot be attached to another workspace's object.
            parent = db.session.query(parent_class).join(Workspace).filter(
                Workspace.name == self.context['workspace_name'],
                parent_class.id == parent_id).one()
        except NoResultFound:
            raise InvalidUsage('Parent id not found: {}'.format(parent_id))
        data[parent_field] = parent.id
        data[not_parent_field] = None
        return data
def test_method_field_passed_deserialize_only_is_load_only(self):
    """A Method field given only a ``deserialize`` name must be load-only."""
    method_field = fields.Method(deserialize="somemethod")
    assert method_field.load_only is True
    assert method_field.dump_only is False
class SearchTextSchema(Schema):
    """Request schema for a text search (NDL-style API parameters)."""

    dpid = fields.String()
    dpgroupid = fields.String()
    title = fields.String(required=True)
    creator = fields.String()
    digitized_publisher = fields.String()
    ndc = fields.String()
    # NOTE(review): ``load_to`` is not a standard marshmallow keyword
    # (marshmallow 2 uses ``load_from``/``dump_to``) -- confirm these
    # fields are actually remapped to 'from'/'until' as intended.
    from_date = fields.String(load_to='from', dump_to='from')
    until_date = fields.String(load_to='until', dump_to='until')
    cnt = fields.Integer(validate=lambda n: 1 <= n <= 500)
    idx = fields.Integer(validate=lambda n: 1 <= n <= 500)
    isbn = fields.String()
    mediatype = fields.Method("validate_mediatype")

    def validate_mediatype(self, obj: dict) -> str:
        """Validate and serialize ``mediatype``.

        Args:
            obj: The request payload dictionary.

        Returns:
            The mediatype as a string: a list becomes a space-separated
            string, an int in [1, 9] is used as-is, and anything else
            (missing or out of range) is coerced to "1".
        """
        if 'mediatype' not in obj:
            return str(1)
        mediatype = obj['mediatype']
        if type(mediatype) is list:
            return ''.join('{} '.format(v) for v in mediatype).rstrip()
        if (type(mediatype) is int) and (1 <= mediatype <= 9):
            # Bug fix: the original converted the (empty) accumulator
            # string ``mediatype_str`` instead of the value itself, so a
            # valid int mediatype always serialized to "".
            return str(mediatype)
        # Out-of-range values are forced to 1.
        return str(1)

    @validates('from_date')
    def validate_from_date(self, value) -> None:
        """Validate ``from_date`` (YYYY, YYYY-MM or YYYY-MM-DD).

        Raises:
            ValidationError: when any date component is non-numeric or
                out of range.
        """
        self._validate_date_value(value, 'from_date')

    @validates('until_date')
    def validate_until_date(self, value) -> None:
        """Validate ``until_date`` (YYYY, YYYY-MM or YYYY-MM-DD).

        Raises:
            ValidationError: when any date component is non-numeric or
                out of range.
        """
        self._validate_date_value(value, 'until_date')

    def _validate_date_value(self, value, field_name):
        """Shared range check for '-'-separated date strings.

        Accepted shapes: YYYY, YYYY-MM, YYYY-MM-DD (extra components
        beyond the third are ignored, matching the original behavior).

        Bug fix: the original wrapped the whole loop in
        ``except Exception``, which also swallowed its own specific
        ValidationErrors, so the detailed year/month/day messages were
        unreachable. Only ``ValueError`` from int() is caught now.
        """
        # (min, max, label) per component position.
        bounds = {0: (1970, 9999, 'year'), 1: (1, 12, 'month'),
                  2: (1, 31, 'day')}
        try:
            for i, part in enumerate(value.split('-')):
                if i in bounds:
                    low, high, unit = bounds[i]
                    if not (low <= int(part) <= high):
                        raise ValidationError(
                            '{} {} is not date type'.format(field_name, unit))
        except ValueError:
            # Non-numeric component.
            raise ValidationError('{} is not date type'.format(field_name))
class BadSerializer(Schema):
    """Test fixture: a schema whose Method field names a non-callable."""

    # ``bad_field`` points at the class attribute ``foo``, which is a
    # plain string rather than a method -- used to assert that Method
    # fields reject non-callable targets.
    foo = "not callable"
    bad_field = fields.Method("foo")
def test_method_field_passed_serialize_only_is_dump_only(self, user):
    """A Method field given only a ``serialize`` name must be dump-only."""
    method_field = fields.Method(serialize="method")
    assert method_field.dump_only is True
    assert method_field.load_only is False
class CommonRecordSchemaV1(Schema, StrictKeysMixin):
    """Common record schema.

    Shared base for record and deposit serialization: exposes the PID,
    concept identifiers and a computed ``links`` object.
    """

    id = fields.Integer(attribute='pid.pid_value', dump_only=True)
    conceptrecid = SanitizedUnicode(attribute='metadata.conceptrecid',
                                    dump_only=True)
    doi = SanitizedUnicode(attribute='metadata.doi', dump_only=True)
    conceptdoi = SanitizedUnicode(attribute='metadata.conceptdoi',
                                  dump_only=True)
    links = fields.Method('dump_links', dump_only=True)
    created = fields.Str(dump_only=True)

    @pre_dump()
    def predump_relations(self, obj):
        """Add relations to the schema context.

        Mutates ``obj['metadata']`` in place: fills in 'relations' when
        missing and strips the non-public 'draft_child_deposit' entry for
        published records.

        NOTE(review): this processor returns None; confirm the pinned
        marshmallow version keeps the mutated ``obj`` rather than using
        the processor's return value.
        """
        m = obj.get('metadata', {})
        if 'relations' not in m:
            pid = self.context['pid']
            # For deposits serialize the record's relations
            if is_deposit(m):
                pid = PersistentIdentifier.get('recid', m['recid'])
            m['relations'] = serialize_relations(pid)
        # Remove some non-public fields
        if is_record(m):
            version_info = m['relations'].get('version', [])
            if version_info:
                version_info[0].pop('draft_child_deposit', None)

    def dump_links(self, obj):
        """Dump links.

        Starts from any links already on ``obj`` and adds common,
        record-only or deposit-only links. BuildError (URL building
        outside a suitable request/app context) is deliberately ignored.
        """
        links = obj.get('links', {})
        if current_app:
            links.update(self._dump_common_links(obj))
            try:
                m = obj.get('metadata', {})
                if is_deposit(m):
                    links.update(self._dump_deposit_links(obj))
                else:
                    links.update(self._dump_record_links(obj))
            except BuildError:
                pass
        return links

    def _thumbnail_url(self, fileobj, thumbnail_size):
        """Create the thumbnail URL for an image."""
        return link_for(
            current_app.config.get('THEME_SITEURL'),
            'thumbnail',
            path=ui_iiif_image_url(
                fileobj,
                size='{},'.format(thumbnail_size),
                # Keep PNGs as PNG; everything else is served as JPEG.
                image_format='png' if fileobj['type'] == 'png' else 'jpg',
            ))

    def _thumbnail_urls(self, recid):
        """Create the thumbnail URL for an image.

        Returns one URL per size listed in CACHED_THUMBNAILS.
        """
        thumbnail_urls = {}
        cached_sizes = current_app.config.get('CACHED_THUMBNAILS')
        for size in cached_sizes:
            thumbnail_urls[size] = link_for(
                current_app.config.get('THEME_SITEURL'),
                'thumbs',
                id=recid,
                size=size)
        return thumbnail_urls

    def _dump_common_links(self, obj):
        """Dump common links for deposits and records.

        Adds badge/DOI links plus thumbnail links derived from the first
        previewable file.
        """
        links = {}
        m = obj.get('metadata', {})

        doi = m.get('doi')
        if doi:
            links['badge'] = ui_link_for('badge', doi=quote(doi))
            links['doi'] = idutils.to_url(doi, 'doi', 'https')

        conceptdoi = m.get('conceptdoi')
        if conceptdoi:
            links['conceptbadge'] = ui_link_for('badge', doi=quote(conceptdoi))
            links['conceptdoi'] = idutils.to_url(conceptdoi, 'doi', 'https')

        files = m.get('_files', [])
        for f in files:
            if f.get('type') in thumbnail_exts:
                try:
                    # First previewable image is used for preview.
                    links['thumbs'] = self._thumbnail_urls(m.get('recid'))
                    links['thumb250'] = self._thumbnail_url(f, 250)
                except RuntimeError:
                    # Thumbnail URL building can fail outside an app
                    # context; links are best-effort here.
                    pass
                # Only the first previewable file is considered.
                break

        return links

    def _dump_record_links(self, obj):
        """Dump record-only links."""
        links = {}
        m = obj.get('metadata')
        bucket_id = m.get('_buckets', {}).get('record')
        recid = m.get('recid')

        if bucket_id:
            links['bucket'] = api_link_for('bucket', bucket=bucket_id)

        links['html'] = ui_link_for('record_html', id=recid)

        # Generate relation links
        links.update(self._dump_relation_links(m))
        return links

    def _dump_deposit_links(self, obj):
        """Dump deposit-only links."""
        links = {}
        m = obj.get('metadata')
        bucket_id = m.get('_buckets', {}).get('deposit')
        recid = m.get('recid')
        # A deposit is published once it carries a 'pid' in '_deposit'.
        is_published = 'pid' in m.get('_deposit', {})

        if bucket_id:
            links['bucket'] = api_link_for('bucket', bucket=bucket_id)

        # Record links
        if is_published:
            links['record'] = api_link_for('record', id=recid)
            links['record_html'] = ui_link_for('record_html', id=recid)

        # Generate relation links
        links.update(self._dump_relation_links(m))
        return links

    def _dump_relation_links(self, metadata):
        """Dump PID relation links.

        Builds 'latest'/'latest_html' from the version relation's last
        child, and (for deposits) 'latest_draft'/'latest_draft_html' from
        the draft child deposit.
        """
        links = {}
        relations = metadata.get('relations')
        if relations:
            version_info = next(iter(relations.get('version', [])), None)
            if version_info:
                last_child = version_info.get('last_child')
                if last_child:
                    links['latest'] = api_link_for(
                        'record', id=last_child['pid_value'])
                    links['latest_html'] = ui_link_for(
                        'record_html', id=last_child['pid_value'])
                if is_deposit(metadata):
                    draft_child_depid = version_info.get('draft_child_deposit')
                    if draft_child_depid:
                        links['latest_draft'] = api_link_for(
                            'deposit', id=draft_child_depid['pid_value'])
                        links['latest_draft_html'] = ui_link_for(
                            'deposit_html', id=draft_child_depid['pid_value'])
        return links

    @post_load(pass_many=False)
    def remove_envelope(self, data):
        """Post process data.

        Unwraps the 'metadata' envelope (if present) and stamps the
        record JSON-Schema URL onto the result.
        """
        # Remove envelope
        if 'metadata' in data:
            data = data['metadata']

        # Record schema.
        data['$schema'] = \
            'https://zenodo.org/schemas/deposits/records/record-v1.0.0.json'
        return data