def test_noncallable_validators(self, member_validator, iterable_validator):
    """
    Raise :class:`TypeError` if any validators are not callable.
    """
    with pytest.raises(TypeError) as e:
        deep_iterable(member_validator, iterable_validator)
    e.match(r"\w* must be callable")
def test_noncallable_validators(self, member_validator, iterable_validator):
    """
    Raise `TypeError` if any validators are not callable.
    """
    with pytest.raises(TypeError) as e:
        deep_iterable(member_validator, iterable_validator)

    value = 42
    message = "must be callable (got {value} that is a {type_}).".format(
        value=value, type_=value.__class__
    )
    assert message in e.value.args[0]
    assert value == e.value.args[1]
    assert message in e.value.msg
    assert value == e.value.value
class TocTree:
    """An individual toctree within a document."""

    # TODO validate uniqueness of docnames (at least one item)
    items: List[Union[GlobItem, FileItem, UrlItem]] = attr.ib(
        validator=deep_iterable(
            instance_of((GlobItem, FileItem, UrlItem)), instance_of(list)
        )
    )
    caption: Optional[str] = attr.ib(
        None, kw_only=True, validator=optional(instance_of(str))
    )
    hidden: bool = attr.ib(True, kw_only=True, validator=instance_of(bool))
    maxdepth: int = attr.ib(-1, kw_only=True, validator=instance_of(int))
    numbered: Union[bool, int] = attr.ib(
        False, kw_only=True, validator=instance_of((bool, int))
    )
    reversed: bool = attr.ib(False, kw_only=True, validator=instance_of(bool))
    titlesonly: bool = attr.ib(False, kw_only=True, validator=instance_of(bool))

    def files(self) -> List[str]:
        return [str(item) for item in self.items if isinstance(item, FileItem)]

    def globs(self) -> List[str]:
        return [str(item) for item in self.items if isinstance(item, GlobItem)]
class MessageHeader(object):  # pylint: disable=too-many-instance-attributes
    """Deserialized message header object.

    :param SerializationVersion version: Message format version, per spec
    :param ObjectType type: Message content type, per spec
    :param AlgorithmSuite algorithm: Algorithm to use for encryption
    :param bytes message_id: Message ID
    :param Dict[str,str] encryption_context: Dictionary defining encryption context
    :param Sequence[EncryptedDataKey] encrypted_data_keys: Encrypted data keys
    :param ContentType content_type: Message content framing type (framed/non-framed)
    :param int content_aad_length: empty
    :param int header_iv_length: Bytes in Initialization Vector value found in header
    :param int frame_length: Length of message frame in bytes
    """

    version = attr.ib(hash=True, validator=instance_of(SerializationVersion))
    type = attr.ib(hash=True, validator=instance_of(ObjectType))
    algorithm = attr.ib(hash=True, validator=instance_of(Algorithm))
    message_id = attr.ib(hash=True, validator=instance_of(bytes))
    encryption_context = attr.ib(
        hash=True,
        validator=deep_mapping(
            key_validator=instance_of(six.string_types),
            value_validator=instance_of(six.string_types),
        ),
    )
    encrypted_data_keys = attr.ib(
        hash=True,
        validator=deep_iterable(member_validator=instance_of(EncryptedDataKey)),
    )
    content_type = attr.ib(hash=True, validator=instance_of(ContentType))
    content_aad_length = attr.ib(hash=True, validator=instance_of(six.integer_types))
    header_iv_length = attr.ib(hash=True, validator=instance_of(six.integer_types))
    frame_length = attr.ib(hash=True, validator=instance_of(six.integer_types))
class DenyRegionsClientSupplier(ClientSupplier):
    """AWS KMS client supplier that supplies clients for any region except for the specified regions.

    .. versionadded:: 1.5.0

    :param List[str] denied_regions: Regions to deny
    :param ClientSupplier client_supplier: Client supplier to wrap (optional)
    """

    denied_regions = attr.ib(
        validator=(
            deep_iterable(member_validator=instance_of(six.string_types)),
            value_is_not_a_string,
        )
    )
    _client_supplier = attr.ib(
        default=attr.Factory(DefaultClientSupplier), validator=optional(is_callable())
    )

    def __call__(self, region_name):
        # type: (Union[None, str]) -> BaseClient
        """Return a client for the requested region.

        :rtype: BaseClient
        :raises UnknownRegionError: if a region is requested that is in ``denied_regions``
        """
        if region_name in self.denied_regions:
            raise UnknownRegionError(
                "Unable to provide client for region '{}'".format(region_name)
            )

        return self._client_supplier(region_name)
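# Hedged usage sketch for the supplier above (the region names are illustrative,
# not from the original source). Note the validator tuple on denied_regions:
# deep_iterable checks that every member is a string, while value_is_not_a_string
# rejects a bare string like "us-east-1" being passed instead of a list.
deny_china = DenyRegionsClientSupplier(
    denied_regions=["cn-north-1", "cn-northwest-1"]
)
client = deny_china("us-west-2")  # allowed: delegated to the wrapped DefaultClientSupplier
try:
    deny_china("cn-north-1")
except UnknownRegionError:
    pass  # denied regions raise instead of returning a client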
def test_success_member_and_iterable(self, member_validator):
    """
    If both the member and iterable validators succeed, nothing happens.
    """
    iterable_validator = instance_of(list)
    v = deep_iterable(member_validator, iterable_validator)
    a = simple_attr("test")
    v(None, a, [42])
def test_fail_invalid_member(self, member_validator):
    """
    Raise member validator error if an invalid member is found.
    """
    v = deep_iterable(member_validator)
    a = simple_attr("test")
    with pytest.raises(TypeError):
        v(None, a, [42, "42"])
def test_success_member_only(self, member_validator):
    """
    If the member validator succeeds and the iterable validator is not set,
    nothing happens.
    """
    v = deep_iterable(member_validator)
    a = simple_attr("test")
    v(None, a, [42])
class PipelineStage(_ConfigStructure):
    """CodePipeline stage definition.

    :param name: Stage name
    :param actions: Actions to be taken in stage
    """

    name: str = attr.ib(validator=instance_of(str))
    actions: Iterable[PipelineAction] = attr.ib(
        validator=deep_iterable(member_validator=instance_of(PipelineAction))
    )
class PerceptionSemanticAlignment:
    """
    Represents an alignment between a perception graph and a set of semantic nodes
    representing concepts.

    This is used to represent intermediate semantic data passed between new-style
    learners when describing a perception.
    """

    perception_graph: PerceptionGraph = attrib(validator=instance_of(PerceptionGraph))
    semantic_nodes: ImmutableSet[SemanticNode] = attrib(
        converter=_to_immutableset, validator=deep_iterable(instance_of(SemanticNode))
    )
    functional_concept_to_object_concept: ImmutableDict[
        FunctionalObjectConcept, ObjectConcept
    ] = attrib(
        converter=_to_immutabledict,
        validator=deep_mapping(
            instance_of(FunctionalObjectConcept), instance_of(ObjectConcept)
        ),
        default=immutabledict(),
    )

    @staticmethod
    def create_unaligned(
        perception_graph: PerceptionGraph,
    ) -> "PerceptionSemanticAlignment":
        return PerceptionSemanticAlignment(perception_graph, [])

    def copy_with_updated_graph_and_added_nodes(
        self, *, new_graph: PerceptionGraph, new_nodes: Iterable[SemanticNode]
    ) -> "PerceptionSemanticAlignment":
        if new_graph is self.perception_graph and not new_nodes:
            return self
        else:
            return PerceptionSemanticAlignment(
                perception_graph=new_graph,
                semantic_nodes=chain(self.semantic_nodes, new_nodes),
            )

    def copy_with_mapping(
        self, *, mapping: Mapping[FunctionalObjectConcept, ObjectConcept]
    ) -> "PerceptionSemanticAlignment":
        return PerceptionSemanticAlignment(
            perception_graph=self.perception_graph,
            semantic_nodes=self.semantic_nodes,
            functional_concept_to_object_concept=mapping,
        )

    def __attrs_post_init__(self) -> None:
        for node in self.perception_graph._graph:  # pylint:disable=protected-access
            if isinstance(node, SemanticNode):
                if node not in self.semantic_nodes:
                    raise RuntimeError(
                        "All semantic nodes appearing in the perception graph must "
                        "also be in semantic_nodes"
                    )
def test_fail_invalid_iterable(self):
    """
    Raise iterable validator error if an invalid iterable is found.
    """
    member_validator = instance_of(int)
    iterable_validator = instance_of(tuple)
    v = deep_iterable(member_validator, iterable_validator)
    a = simple_attr("test")
    with pytest.raises(TypeError):
        v(None, a, [42])
def test_fail_invalid_member_and_iterable(self, member_validator):
    """
    Raise iterable validator error if both the iterable and a member are invalid.
    """
    iterable_validator = instance_of(tuple)
    v = deep_iterable(member_validator, iterable_validator)
    a = simple_attr("test")
    with pytest.raises(TypeError):
        v(None, a, [42, "42"])
class MdParserConfig:
    """Configuration options for the Markdown Parser.

    Note in the sphinx configuration these option names are prepended with ``myst_``
    """

    renderer: str = attr.ib(
        default="sphinx", validator=in_(["sphinx", "html", "docutils"])
    )
    commonmark_only: bool = attr.ib(default=False, validator=instance_of(bool))
    dmath_enable: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_labels: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_space: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_digits: bool = attr.ib(default=True, validator=instance_of(bool))
    amsmath_enable: bool = attr.ib(default=False, validator=instance_of(bool))
    deflist_enable: bool = attr.ib(default=False, validator=instance_of(bool))
    update_mathjax: bool = attr.ib(default=True, validator=instance_of(bool))
    admonition_enable: bool = attr.ib(default=False, validator=instance_of(bool))
    figure_enable: bool = attr.ib(default=False, validator=instance_of(bool))

    disable_syntax: List[str] = attr.ib(
        factory=list,
        validator=deep_iterable(instance_of(str), instance_of((list, tuple))),
    )

    html_img_enable: bool = attr.ib(default=False, validator=instance_of(bool))

    # see https://en.wikipedia.org/wiki/List_of_URI_schemes
    url_schemes: Optional[List[str]] = attr.ib(
        default=None,
        validator=optional(
            deep_iterable(instance_of(str), instance_of((list, tuple)))
        ),
    )

    heading_anchors: Optional[int] = attr.ib(
        default=None, validator=optional(in_([1, 2, 3, 4, 5, 6, 7]))
    )

    def as_dict(self, dict_factory=dict) -> dict:
        return attr.asdict(self, dict_factory=dict_factory)
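# Illustrative sketch (not part of the original source): disable_syntax combines
# a member validator with an iterable validator, so a bare string is rejected by
# instance_of((list, tuple)) even though each of its characters is a str. The
# iterable_validator runs first, then the member validator runs on each element.
config = MdParserConfig(disable_syntax=["emphasis", "inline"])  # OK: list of str
try:
    MdParserConfig(disable_syntax="emphasis")  # fails the iterable_validator
except TypeError:
    pass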
def test_repr_member_only(self):
    """
    Returned validator has a useful `__repr__`
    when only member validator is set.
    """
    member_validator = instance_of(int)
    member_repr = "<instance_of validator for type <class 'int'>>"
    v = deep_iterable(member_validator)
    expected_repr = "<deep_iterable validator for iterables of {member_repr}>".format(
        member_repr=member_repr
    )
    assert expected_repr == repr(v)
class KeyringTrace(object):
    """Record of all actions that a KeyRing performed with a wrapping key.

    .. versionadded:: 1.5.0

    :param MasterKeyInfo wrapping_key: Wrapping key used
    :param Set[KeyringTraceFlag] flags: Actions performed
    """

    wrapping_key = attr.ib(validator=instance_of(MasterKeyInfo))
    flags = attr.ib(
        validator=deep_iterable(member_validator=instance_of(KeyringTraceFlag))
    )
def list_of(type_: type) -> Callable:
    """
    An attr validator that performs validation of list values.

    :param type_: The type to check for, can be a type or tuple of types
    :raises TypeError: raises a `TypeError` if the initializer is called with a
        wrong type for this particular attribute
    :return: An attr validator that performs validation of values of a list.
    """
    return deep_iterable(
        member_validator=instance_of(type_), iterable_validator=instance_of(list)
    )
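# Hedged usage sketch for list_of (the Basket class is illustrative, not from
# the original source): the returned validator enforces both that the value is
# a list and that every member has the requested type.
@attr.s
class Basket:
    weights = attr.ib(validator=list_of(int))

Basket(weights=[1, 2, 3])  # passes: a list whose members are all ints
try:
    Basket(weights=(1, 2, 3))  # fails: a tuple does not satisfy instance_of(list)
except TypeError:
    pass
try:
    Basket(weights=[1, "2"])  # fails: "2" does not satisfy the member validator
except TypeError:
    pass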
def test_repr_member_and_iterable(self):
    """
    Returned validator has a useful `__repr__` when both member
    and iterable validators are set.
    """
    member_validator = instance_of(int)
    member_repr = "<instance_of validator for type <class 'int'>>"
    iterable_validator = instance_of(list)
    iterable_repr = "<instance_of validator for type <class 'list'>>"
    v = deep_iterable(member_validator, iterable_validator)
    expected_repr = (
        "<deep_iterable validator for"
        " {iterable_repr} iterables of {member_repr}>"
    ).format(iterable_repr=iterable_repr, member_repr=member_repr)
    assert expected_repr == repr(v)
def test_repr_member_only_sequence(self):
    """
    Returned validator has a useful `__repr__` when only member validator
    is set and the member validator is a list of validators
    """
    member_validator = [always_pass, instance_of(int)]
    member_repr = (
        "_AndValidator(_validators=({func}, "
        "<instance_of validator for type <class 'int'>>))"
    ).format(func=repr(always_pass))
    v = deep_iterable(member_validator)
    expected_repr = "<deep_iterable validator for iterables of {member_repr}>".format(
        member_repr=member_repr
    )
    assert expected_repr == repr(v)
class TokenSequenceLinguisticDescription(LinguisticDescription):
    """
    A `LinguisticDescription` which consists of a sequence of tokens.
    """

    tokens: Tuple[str, ...] = attrib(
        converter=_to_tuple, validator=deep_iterable(instance_of(str))
    )

    def as_token_sequence(self) -> Tuple[str, ...]:
        return self.tokens

    def __getitem__(self, item) -> str:
        return self.tokens[item]

    def __len__(self) -> int:
        return len(self.tokens)
class ByHierarchyAndProperties(OntologyNodeSelector):
    """
    An `OntologyNodeSelector` which selects all nodes which are descendents of
    *descendents_of*, which possess all of *required_properties*,
    and which possess none of *banned_properties*.
    """

    _descendents_of: OntologyNode = attrib(validator=instance_of(OntologyNode))
    _required_properties: ImmutableSet[OntologyNode] = attrib(
        converter=_to_immutableset, default=immutableset()
    )
    _banned_properties: ImmutableSet[OntologyNode] = attrib(
        converter=_to_immutableset, default=immutableset()
    )
    _banned_ontology_types: ImmutableSet[OntologyNode] = attrib(
        converter=_to_immutableset,
        default=immutableset(),
        validator=deep_iterable(instance_of(OntologyNode)),
    )

    def _select_nodes(self, ontology: Ontology) -> AbstractSet[OntologyNode]:
        return ontology.nodes_with_properties(
            self._descendents_of,
            self._required_properties,
            banned_properties=self._banned_properties,
            banned_ontology_types=self._banned_ontology_types,
        )

    def __repr__(self) -> str:
        required_properties = [
            f"+{property_}" for property_ in self._required_properties
        ]
        banned_properties = [
            f"-{property_}" for property_ in self._banned_properties
        ]
        banned_ontology_types = [
            f"-{ontology_type}" for ontology_type in self._banned_ontology_types
        ]

        property_string: str
        if required_properties or banned_properties or banned_ontology_types:
            property_string = (
                f"[{', '.join(chain(required_properties, banned_properties, banned_ontology_types))}]"
            )
        else:
            property_string = ""

        return f"ancestorIs({self._descendents_of.handle}){property_string}"
class Document:
    """A document in the site map."""

    docname: str = attr.ib(validator=instance_of(str))
    title: Optional[str] = attr.ib(None, validator=optional(instance_of(str)))
    # TODO validate uniqueness of docnames across all parts (and none should be the docname)
    subtrees: List[TocTree] = attr.ib(
        factory=list,
        validator=deep_iterable(instance_of(TocTree), instance_of(list)),
    )

    def child_files(self) -> List[str]:
        """Return all children files."""
        return [name for tree in self.subtrees for name in tree.files()]

    def child_globs(self) -> List[str]:
        """Return all children globs."""
        return [name for tree in self.subtrees for name in tree.globs()]
class DecryptionMaterialsRequest(object):
    """Request object to provide to a crypto material manager's `decrypt_materials` method.

    .. versionadded:: 1.3.0

    :param algorithm: Algorithm to provide to master keys for underlying decrypt requests
    :type algorithm: aws_encryption_sdk.identifiers.Algorithm
    :param encrypted_data_keys: Set of encrypted data keys
    :type encrypted_data_keys: set of `aws_encryption_sdk.structures.EncryptedDataKey`
    :param dict encryption_context: Encryption context to provide to master keys for underlying decrypt requests
    """

    algorithm = attr.ib(validator=instance_of(Algorithm))
    encrypted_data_keys = attr.ib(
        validator=deep_iterable(member_validator=instance_of(EncryptedDataKey))
    )
    encryption_context = attr.ib(
        validator=deep_mapping(
            key_validator=instance_of(six.string_types),
            value_validator=instance_of(six.string_types),
        )
    )
def test_repr_sequence_member_and_iterable(self):
    """
    Returned validator has a useful `__repr__` when both member and iterable
    validators are set and the member validator is a list of validators
    """
    member_validator = [always_pass, instance_of(int)]
    member_repr = (
        "_AndValidator(_validators=({func}, "
        "<instance_of validator for type <class 'int'>>))"
    ).format(func=repr(always_pass))
    iterable_validator = instance_of(list)
    iterable_repr = "<instance_of validator for type <class 'list'>>"
    v = deep_iterable(member_validator, iterable_validator)
    expected_repr = (
        "<deep_iterable validator for"
        " {iterable_repr} iterables of {member_repr}>"
    ).format(iterable_repr=iterable_repr, member_repr=member_repr)
    assert expected_repr == repr(v)
def attrib_instance_list(type_: type) -> attr._make._CountingAttr:
    """
    Create a new attr attribute with a validator for the given type_.

    All attributes created are expected to be a List of the given type_.

    :param type_: The type that members of the list are validated against.
    """
    # Configure the validator
    _validator = deep_iterable(
        member_validator=instance_of(type_),
        iterable_validator=instance_of(list),
    )
    metadata = {"type": "instance_list", "class_": type_}
    return attr.ib(
        default=attr.Factory(list),
        validator=_validator,
        type=List[type_],
        metadata=metadata,
    )
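# Hedged usage sketch for the helper above (Wheel and Car are illustrative, not
# from the original source).
@attr.s
class Wheel:
    size = attr.ib(default=16)

@attr.s
class Car:
    # Defaults to an empty list; every member must be a Wheel instance,
    # and the metadata records the element class for downstream tooling.
    wheels = attrib_instance_list(Wheel)

Car()                  # wheels defaults to []
Car(wheels=[Wheel()])  # passes both the iterable and member validators
try:
    Car(wheels=Wheel())  # fails: not a list
except TypeError:
    pass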
class _AwsKmsDiscoveryKeyring(Keyring):
    """AWS KMS discovery keyring that will attempt to decrypt any AWS KMS encrypted data key.

    This keyring should never be used directly.
    It should only ever be used internally by :class:`AwsKmsKeyring`.

    .. versionadded:: 1.5.0

    :param ClientSupplier client_supplier: Client supplier to use when asking for clients
    :param List[str] grant_tokens: AWS KMS grant tokens to include in requests (optional)
    """

    _client_supplier = attr.ib(validator=is_callable())
    _grant_tokens = attr.ib(
        default=attr.Factory(tuple),
        validator=(
            deep_iterable(member_validator=instance_of(six.string_types)),
            value_is_not_a_string,
        ),
    )

    def on_encrypt(self, encryption_materials):
        # type: (EncryptionMaterials) -> EncryptionMaterials
        return encryption_materials

    def on_decrypt(self, decryption_materials, encrypted_data_keys):
        # type: (DecryptionMaterials, Iterable[EncryptedDataKey]) -> DecryptionMaterials
        new_materials = decryption_materials

        for edk in encrypted_data_keys:
            if new_materials.data_encryption_key is not None:
                return new_materials

            if edk.key_provider.provider_id == KEY_NAMESPACE:
                new_materials = _try_aws_kms_decrypt(
                    client_supplier=self._client_supplier,
                    decryption_materials=new_materials,
                    grant_tokens=self._grant_tokens,
                    encrypted_data_key=edk,
                )

        return new_materials
class CryptoResult(object):
    """Result container for one-shot cryptographic API results.

    .. versionadded:: 1.5.0

    .. note::

        For backwards compatibility, this container also unpacks like a 2-member tuple.
        This allows for backwards compatibility with the previous outputs.

    :param bytes result: Binary results of the cryptographic operation
    :param MessageHeader header: Encrypted message metadata
    :param Tuple[KeyringTrace] keyring_trace: Keyring trace entries
    """

    result = attr.ib(validator=instance_of(bytes))
    header = attr.ib(validator=instance_of(MessageHeader))
    keyring_trace = attr.ib(
        validator=deep_iterable(member_validator=instance_of(KeyringTrace))
    )

    def __attrs_post_init__(self):
        """Construct the inner tuple for backwards compatibility."""
        self._legacy_container = (self.result, self.header)

    def __len__(self):
        """Emulate the inner tuple."""
        return self._legacy_container.__len__()

    def __iter__(self):
        """Emulate the inner tuple."""
        return self._legacy_container.__iter__()

    def __getitem__(self, key):
        """Emulate the inner tuple."""
        return self._legacy_container.__getitem__(key)
def attrib_instance_list(
    type_, *, validator_type: Optional[Tuple[Any, ...]] = None
) -> attr._make._CountingAttr:
    """
    Create a new attr attribute with a validator for the given type_.

    All attributes created are expected to be a List of the given type_.

    :param validator_type: Which type(s) the list members should be validated
        against. When not given, type_ itself is used.
    """
    # Configure the validator
    _validator_type = validator_type or type_
    _validator = deep_iterable(
        member_validator=instance_of(_validator_type),
        iterable_validator=instance_of(list),
    )
    return attr.ib(default=attr.Factory(list), validator=_validator, type=List[type_])
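# Hedged sketch of the validator_type override (the shape classes are
# illustrative): the declared element type stays List[Circle], but members may
# be validated against a wider tuple of acceptable types.
@attr.s
class Circle:
    radius = attr.ib(default=1.0)

@attr.s
class Ellipse:
    a = attr.ib(default=1.0)
    b = attr.ib(default=2.0)

@attr.s
class Drawing:
    shapes = attrib_instance_list(Circle, validator_type=(Circle, Ellipse))

Drawing(shapes=[Circle(), Ellipse()])  # both pass instance_of((Circle, Ellipse))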
class MdParserConfig:
    """Configuration options for the Markdown Parser.

    Note in the sphinx configuration these option names are prepended with ``myst_``
    """

    renderer: str = attr.ib(
        default="sphinx", validator=in_(["sphinx", "html", "docutils"])
    )
    commonmark_only: bool = attr.ib(default=False, validator=instance_of(bool))
    dmath_allow_labels: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_space: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_digits: bool = attr.ib(default=True, validator=instance_of(bool))
    update_mathjax: bool = attr.ib(default=True, validator=instance_of(bool))

    # TODO remove deprecated _enable attributes after v0.13.0
    admonition_enable: bool = attr.ib(
        default=False, validator=instance_of(bool), repr=False
    )
    figure_enable: bool = attr.ib(
        default=False, validator=instance_of(bool), repr=False
    )
    dmath_enable: bool = attr.ib(
        default=False, validator=instance_of(bool), repr=False
    )
    amsmath_enable: bool = attr.ib(
        default=False, validator=instance_of(bool), repr=False
    )
    deflist_enable: bool = attr.ib(
        default=False, validator=instance_of(bool), repr=False
    )
    html_img_enable: bool = attr.ib(
        default=False, validator=instance_of(bool), repr=False
    )
    colon_fence_enable: bool = attr.ib(
        default=False, validator=instance_of(bool), repr=False
    )

    enable_extensions: Iterable[str] = attr.ib(factory=lambda: ["dollarmath"])

    @enable_extensions.validator
    def check_extensions(self, attribute, value):
        if not isinstance(value, Iterable):
            raise TypeError(f"myst_enable_extensions not iterable: {value}")
        diff = set(value).difference(
            [
                "dollarmath",
                "amsmath",
                "deflist",
                "html_image",
                "colon_fence",
                "smartquotes",
                "replacements",
                "linkify",
                "substitution",
            ]
        )
        if diff:
            raise ValueError(f"myst_enable_extensions not recognised: {diff}")

    disable_syntax: List[str] = attr.ib(
        factory=list,
        validator=deep_iterable(instance_of(str), instance_of((list, tuple))),
    )

    # see https://en.wikipedia.org/wiki/List_of_URI_schemes
    url_schemes: Optional[List[str]] = attr.ib(
        default=None,
        validator=optional(
            deep_iterable(instance_of(str), instance_of((list, tuple)))
        ),
    )

    heading_anchors: Optional[int] = attr.ib(
        default=None, validator=optional(in_([1, 2, 3, 4, 5, 6, 7]))
    )

    substitutions: Dict[str, str] = attr.ib(
        factory=dict,
        validator=deep_mapping(
            instance_of(str), instance_of((str, int, float)), instance_of(dict)
        ),
        repr=lambda v: str(list(v)),
    )

    sub_delimiters: Tuple[str, str] = attr.ib(default=("{", "}"))

    @sub_delimiters.validator
    def check_sub_delimiters(self, attribute, value):
        if (not isinstance(value, (tuple, list))) or len(value) != 2:
            raise TypeError(f"myst_sub_delimiters is not a tuple of length 2: {value}")
        for delim in value:
            if (not isinstance(delim, str)) or len(delim) != 1:
                raise TypeError(
                    f"myst_sub_delimiters does not contain strings of length 1: {value}"
                )

    def as_dict(self, dict_factory=dict) -> dict:
        return attr.asdict(self, dict_factory=dict_factory)
class Param:
    """Specification of a parameter that is to be constrained."""

    name = attr.ib()
    _min: numeric = attr.ib(-np.inf)
    _max: numeric = attr.ib(np.inf)
    prior = attr.ib(
        kw_only=True,
        validator=vld.optional(vld.instance_of(stats.distributions.rv_frozen)),
    )
    fiducial = attr.ib(
        None,
        type=float,
        converter=cnv.optional(float),
        validator=vld.optional(vld.instance_of(float)),
        kw_only=True,
    )
    latex = attr.ib(kw_only=True)
    ref = attr.ib(kw_only=True)
    determines = attr.ib(
        converter=tuplify,
        kw_only=True,
        validator=vld.deep_iterable(vld.instance_of(str)),
    )
    transforms = attr.ib(converter=tuplify, kw_only=True)

    @latex.default
    def _ltx_default(self):
        return texify(self.name)

    @ref.default
    def _ref_default(self):
        return self.prior

    @prior.default
    def _prior_default(self) -> stats.distributions.rv_frozen | None:
        if np.isinf(self._min) or np.isinf(self._max):
            return None
        return stats.uniform(self._min, self._max - self._min)

    @determines.default
    def _determines_default(self):
        return (self.name,)

    @transforms.default
    def _transforms_default(self):
        return (None,) * len(self.determines)

    @transforms.validator
    def _transforms_validator(self, attribute, value):
        for val in value:
            if val is not None and not callable(val):
                raise TypeError("transforms must be a list of callables")

    @property
    def min(self) -> float:
        """The minimum boundary of the prior, helpful for constraints."""
        if self.prior is None:
            return self._min
        elif isinstance(self.prior, type(stats.uniform(0, 1))):
            return self.prior.support()[0]
        else:
            return -np.inf

    @property
    def max(self) -> float:
        """The maximum boundary of the prior, helpful for constraints."""
        if self.prior is None:
            return self._max
        elif isinstance(self.prior, type(stats.uniform(0, 1))):
            return self.prior.support()[1]
        else:
            return np.inf

    @cached_property
    def is_alias(self):
        return all(pm is None for pm in self.transforms)

    @cached_property
    def is_pure_alias(self):
        return self.is_alias and len(self.determines) == 1

    def transform(self, val):
        for pm in self.transforms:
            if pm is None:
                yield val
            else:
                yield pm(val)

    def generate_ref(self, n=1):
        if self.ref is None:
            raise ValueError("Must specify a valid function for ref to generate refs.")

        try:
            ref = self.ref.rvs(size=n)
        except AttributeError:
            try:
                ref = self.ref(size=n)
            except TypeError:
                raise TypeError(
                    f"parameter '{self.name}' does not have a valid value for ref"
                )

        if np.any(self.prior.pdf(ref) == 0):
            raise ValueError(
                f"param {self.name} produced a reference value outside its domain."
            )

        return ref

    def logprior(self, val):
        if self.prior is None:
            if self._min > val or self._max < val:
                return -np.inf
            else:
                return 0

        return self.prior.logpdf(val)

    def clone(self, **kwargs):
        return attr.evolve(self, **kwargs)

    def new(self, p: Parameter) -> Param:
        """Create a new :class:`Param`, with any missing info in this instance
        filled in by the given instance.
        """
        assert isinstance(p, Parameter)
        assert self.determines == (p.name,)
        if len(self.determines) > 1:
            raise ValueError("Cannot create new Param if it is not just an alias")

        default_range = (
            list(self.transform(p.min))[0],
            list(self.transform(p.max))[0],
        )

        return Param(
            name=self.name,
            min=max(self._min, min(default_range)),
            max=min(self._max, max(default_range)),
            fiducial=self.fiducial if self.fiducial is not None else p.fiducial,
            latex=self.latex
            if (self.latex != self.name or self.name != p.name)
            else p.latex,
            ref=self.ref or attr.NOTHING,
            prior=self.prior or attr.NOTHING,
            determines=self.determines,
            transforms=self.transforms,
        )

    def __getstate__(self):
        """Obtain a simple input state of the class that can initialize it."""
        out = attr.asdict(self)
        if self.transforms == (None,):
            del out["transforms"]
        if self.ref is None:
            del out["ref"]
        if self.determines == (self.name,):
            del out["determines"]
        if self.latex == self.name:
            del out["latex"]
        return out

    def as_dict(self):
        """Simple representation of the class as a dict.

        No "name" is included in the dict.
        """
        out = self.__getstate__()
        del out["name"]
        return out
class ObjectRecognizer:
    """
    The ObjectRecognizer finds object matches in the scene pattern and adds an
    `ObjectSemanticNodePerceptionPredicate` which can be used to learn additional
    semantics which relate objects to other objects.

    If applied to a dynamic situation, this will only recognize objects
    which are present in both the BEFORE and AFTER frames.
    """

    # Because static patterns must be applied to static perceptions
    # and dynamic patterns to dynamic situations,
    # we need to store our patterns both ways.
    _concepts_to_static_patterns: ImmutableDict[
        ObjectConcept, PerceptionGraphPattern
    ] = attrib(
        validator=deep_mapping(
            instance_of(ObjectConcept), instance_of(PerceptionGraphPattern)
        ),
        converter=_to_immutabledict,
    )
    _concepts_to_names: ImmutableDict[ObjectConcept, str] = attrib(
        validator=deep_mapping(instance_of(ObjectConcept), instance_of(str)),
        converter=_to_immutabledict,
    )

    # We derive these from the static patterns.
    _concepts_to_dynamic_patterns: ImmutableDict[
        ObjectConcept, PerceptionGraphPattern
    ] = attrib(init=False)

    determiners: ImmutableSet[str] = attrib(
        converter=_to_immutableset, validator=deep_iterable(instance_of(str))
    )
    """
    This is a hack to handle determiners.
    See https://github.com/isi-vista/adam/issues/498
    """

    _concept_to_num_subobjects: ImmutableDict[Concept, int] = attrib(init=False)
    """
    Used for a performance optimization in match_objects.
    """

    _language_mode: LanguageMode = attrib(
        validator=instance_of(LanguageMode), kw_only=True
    )

    def __attrs_post_init__(self) -> None:
        non_lowercase_determiners = [
            determiner
            for determiner in self.determiners
            if determiner.lower() != determiner
        ]
        if non_lowercase_determiners:
            raise RuntimeError(
                f"All determiners must be specified in lowercase, but got "
                f"{non_lowercase_determiners}"
            )

    @staticmethod
    def for_ontology_types(
        ontology_types: Iterable[OntologyNode],
        determiners: Iterable[str],
        ontology: Ontology,
        language_mode: LanguageMode,
        *,
        perception_generator: HighLevelSemanticsSituationToDevelopmentalPrimitivePerceptionGenerator,
    ) -> "ObjectRecognizer":
        ontology_types_to_concepts = {
            obj_type: ObjectConcept(obj_type.handle) for obj_type in ontology_types
        }

        return ObjectRecognizer(
            concepts_to_static_patterns=_sort_mapping_by_pattern_complexity(
                immutabledict(
                    (
                        concept,
                        PerceptionGraphPattern.from_ontology_node(
                            obj_type,
                            ontology,
                            perception_generator=perception_generator,
                        ),
                    )
                    for (obj_type, concept) in ontology_types_to_concepts.items()
                )
            ),
            determiners=determiners,
            concepts_to_names={
                concept: obj_type.handle
                for obj_type, concept in ontology_types_to_concepts.items()
            },
            language_mode=language_mode,
        )

    def match_objects_old(
        self, perception_graph: PerceptionGraph
    ) -> PerceptionGraphFromObjectRecognizer:
        new_style_input = PerceptionSemanticAlignment(
            perception_graph=perception_graph, semantic_nodes=[]
        )
        new_style_output = self.match_objects(new_style_input)
        return PerceptionGraphFromObjectRecognizer(
            perception_graph=new_style_output[0].perception_graph,
            description_to_matched_object_node=new_style_output[1],
        )

    def match_objects(
        self,
        perception_semantic_alignment: PerceptionSemanticAlignment,
        *,
        post_process: Callable[
            [PerceptionGraph, AbstractSet[SemanticNode]],
            Tuple[PerceptionGraph, AbstractSet[SemanticNode]],
        ] = default_post_process_enrichment,
    ) -> Tuple[
        PerceptionSemanticAlignment, Mapping[Tuple[str, ...], ObjectSemanticNode]
    ]:
        r"""
        Recognize known objects in a `PerceptionGraph`.

        The matched portion of the graph will be replaced with `ObjectSemanticNode`\ s
        which will inherit all relationships of any nodes internal to the matched
        portion with any external nodes.

        This is useful as a pre-processing step
        before prepositional and verbal learning experiments.
        """
        # pylint: disable=global-statement,invalid-name
        global cumulative_millis_in_successful_matches_ms
        global cumulative_millis_in_failed_matches_ms

        object_nodes: List[Tuple[Tuple[str, ...], ObjectSemanticNode]] = []
        perception_graph = perception_semantic_alignment.perception_graph
        is_dynamic = perception_semantic_alignment.perception_graph.dynamic

        if is_dynamic:
            concepts_to_patterns = self._concepts_to_dynamic_patterns
        else:
            concepts_to_patterns = self._concepts_to_static_patterns

        # We special-case handling of the ground perception
        # because we don't want to remove it from the graph; we just want to use its
        # object node as a recognized object. The situation "a box on the ground"
        # prompted the need to recognize the ground.
        graph_to_return = perception_graph
        for node in graph_to_return._graph.nodes:  # pylint:disable=protected-access
            if node == GROUND_PERCEPTION:
                matched_object_node = ObjectSemanticNode(GROUND_OBJECT_CONCEPT)
                if LanguageMode.ENGLISH == self._language_mode:
                    object_nodes.append(
                        (
                            (f"{GROUND_OBJECT_CONCEPT.debug_string}",),
                            matched_object_node,
                        )
                    )
                elif LanguageMode.CHINESE == self._language_mode:
                    object_nodes.append((("di4 myan4",), matched_object_node))
                else:
                    raise RuntimeError("Invalid language_generator")
                # We construct a fake match which is only the ground perception node
                subgraph_of_root = subgraph(perception_graph.copy_as_digraph(), [node])
                pattern_match = PerceptionGraphPatternMatch(
                    matched_pattern=PerceptionGraphPattern(
                        graph=subgraph_of_root, dynamic=perception_graph.dynamic
                    ),
                    graph_matched_against=perception_graph,
                    matched_sub_graph=PerceptionGraph(
                        graph=subgraph_of_root, dynamic=perception_graph.dynamic
                    ),
                    pattern_node_to_matched_graph_node=immutabledict(),
                )
                graph_to_return = replace_match_with_object_graph_node(
                    matched_object_node, graph_to_return, pattern_match
                )

        candidate_object_subgraphs = extract_candidate_objects(perception_graph)

        for candidate_object_graph in candidate_object_subgraphs:
            num_object_nodes = candidate_object_graph.count_nodes_matching(
                lambda node: isinstance(node, ObjectPerception)
            )
            for (concept, pattern) in concepts_to_patterns.items():
                # As an optimization, we count how many sub-object nodes
                # are in the graph and the pattern.
                # If they aren't the same, the match is impossible
                # and we can bail out early.
                if num_object_nodes != self._concept_to_num_subobjects[concept]:
                    continue

                with Timer(factor=1000) as t:
                    matcher = pattern.matcher(
                        candidate_object_graph, match_mode=MatchMode.OBJECT
                    )
                    pattern_match = first(
                        matcher.matches(use_lookahead_pruning=True), None
                    )
                if pattern_match:
                    cumulative_millis_in_successful_matches_ms += t.elapsed
                    matched_object_node = ObjectSemanticNode(concept)

                    # We wrap the concept in a tuple because it could in theory
                    # be multiple tokens, even though currently it never is.
                    if self._language_mode == LanguageMode.ENGLISH:
                        object_nodes.append(
                            ((concept.debug_string,), matched_object_node)
                        )
                    elif self._language_mode == LanguageMode.CHINESE:
                        if concept.debug_string == "me":
                            object_nodes.append((("wo3",), matched_object_node))
                        elif concept.debug_string == "you":
                            object_nodes.append((("ni3",), matched_object_node))
                        mappings = (
                            GAILA_PHASE_1_CHINESE_LEXICON._ontology_node_to_word  # pylint:disable=protected-access
                        )
                        for k, v in mappings.items():
                            if k.handle == concept.debug_string:
                                debug_string = str(v.base_form)
                                object_nodes.append(
                                    ((debug_string,), matched_object_node)
                                )
                    graph_to_return = replace_match_with_object_graph_node(
                        matched_object_node, graph_to_return, pattern_match
                    )
                    # We match each candidate object against only one object type.
                    # See https://github.com/isi-vista/adam/issues/627
                    break
                else:
                    cumulative_millis_in_failed_matches_ms += t.elapsed
        if object_nodes:
            logging.info(
                "Object recognizer recognized: %s",
                [concept for (concept, _) in object_nodes],
            )
        logging.info(
            "object matching: ms in success: %s, ms in failed: %s",
            cumulative_millis_in_successful_matches_ms,
            cumulative_millis_in_failed_matches_ms,
        )
        semantic_object_nodes = immutableset(node for (_, node) in object_nodes)
        post_process_graph, post_process_nodes = post_process(
            graph_to_return, semantic_object_nodes
        )
        return (
            perception_semantic_alignment.copy_with_updated_graph_and_added_nodes(
                new_graph=post_process_graph, new_nodes=post_process_nodes
            ),
            immutabledict(object_nodes),
        )

    def match_objects_with_language_old(
        self, language_aligned_perception: LanguageAlignedPerception
    ) -> LanguageAlignedPerception:
        if language_aligned_perception.node_to_language_span:
            raise RuntimeError(
                "Don't know how to handle a non-empty node-to-language-span"
            )
        new_style_input = LanguagePerceptionSemanticAlignment(
            language_concept_alignment=LanguageConceptAlignment(
                language_aligned_perception.language, node_to_language_span=[]
            ),
            perception_semantic_alignment=PerceptionSemanticAlignment(
                perception_graph=language_aligned_perception.perception_graph,
                semantic_nodes=[],
            ),
        )
        new_style_output = self.match_objects_with_language(new_style_input)
        return LanguageAlignedPerception(
            language=new_style_output.language_concept_alignment.language,
            perception_graph=new_style_output.perception_semantic_alignment.perception_graph,
            node_to_language_span=new_style_output.language_concept_alignment.node_to_language_span,
        )

    def match_objects_with_language(
        self,
        language_perception_semantic_alignment: LanguagePerceptionSemanticAlignment,
        *,
        post_process: Callable[
            [PerceptionGraph, AbstractSet[SemanticNode]],
            Tuple[PerceptionGraph, AbstractSet[SemanticNode]],
        ] = default_post_process_enrichment,
    ) -> LanguagePerceptionSemanticAlignment:
        """
        Recognize known objects in a `LanguagePerceptionSemanticAlignment`.

        For each node matched, this will identify the relevant portion of the
        linguistic input and record the correspondence.

        The matched portion of the graph will be replaced with an `ObjectSemanticNode`
        which will inherit all relationships of any nodes internal to the matched
        portion with any external nodes.

        This is useful as a pre-processing step
        before prepositional and verbal learning experiments.
        """
        if (
            language_perception_semantic_alignment.perception_semantic_alignment.semantic_nodes
        ):
            raise RuntimeError(
                "We assume ObjectRecognizer is run first, with no previous alignments"
            )
        (
            post_match_perception_semantic_alignment,
            tokens_to_object_nodes,
        ) = self.match_objects(
            language_perception_semantic_alignment.perception_semantic_alignment,
            post_process=post_process,
        )
        return LanguagePerceptionSemanticAlignment(
            language_concept_alignment=language_perception_semantic_alignment.language_concept_alignment.copy_with_added_token_alignments(
                self._align_objects_to_tokens(
                    tokens_to_object_nodes,
                    language_perception_semantic_alignment.language_concept_alignment.language,
                )
            ),
            perception_semantic_alignment=post_match_perception_semantic_alignment,
        )

    def _align_objects_to_tokens(
        self,
        description_to_object_node: Mapping[Tuple[str, ...], ObjectSemanticNode],
        language: LinguisticDescription,
    ) -> Mapping[ObjectSemanticNode, Span]:
        result: List[Tuple[ObjectSemanticNode, Span]] = []
        # We want to ban the same token index from being aligned twice.
        matched_token_indices: Set[int] = set()
        for (description_tuple, object_node) in description_to_object_node.items():
            if len(description_tuple) != 1:
                raise RuntimeError(
                    f"Multi-token descriptions are not yet supported: "
                    f"{description_tuple}"
                )
            description = description_tuple[0]
            try:
                end_index_inclusive = language.index(description)
            except ValueError:
                # A scene might contain things which are not referred to
                # by the associated language.
                continue

            start_index = end_index_inclusive
            # This is a somewhat language-dependent hack to gobble up
            # preceding determiners.
            # See https://github.com/isi-vista/adam/issues/498 .
            if end_index_inclusive > 0:
                possible_determiner_index = end_index_inclusive - 1
                if language[possible_determiner_index].lower() in self.determiners:
                    start_index = possible_determiner_index

            # We record what tokens were covered so we can block the same tokens
            # being used twice.
            for included_token_index in range(start_index, end_index_inclusive + 1):
                if included_token_index in matched_token_indices:
                    raise RuntimeError(
                        "We do not currently support the same object "
                        "being mentioned twice in a sentence."
                    )
                matched_token_indices.add(included_token_index)

            result.append(
                (
                    object_node,
                    language.span(
                        start_index, end_index_exclusive=end_index_inclusive + 1
                    ),
                )
            )
        return immutabledict(result)

    @_concepts_to_dynamic_patterns.default
    def _init_concepts_to_dynamic_patterns(
        self,
    ) -> ImmutableDict[ObjectConcept, PerceptionGraphPattern]:
        return immutabledict(
            (concept, static_pattern.copy_with_temporal_scopes(ENTIRE_SCENE))
            for (concept, static_pattern) in self._concepts_to_static_patterns.items()
        )

    @_concept_to_num_subobjects.default
    def _init_patterns_to_num_subobjects(self) -> ImmutableDict[ObjectConcept, int]:
        return immutabledict(
            (
                concept,
                pattern.count_nodes_matching(
                    lambda node: isinstance(node, AnyObjectPerception)
                ),
            )
            for (concept, pattern) in self._concepts_to_static_patterns.items()
        )