Example #1
    def test_noncallable_validators(
        self, member_validator, iterable_validator
    ):
        """
        Raise :class:`TypeError` if any validators are not callable.
        """
        with pytest.raises(TypeError) as e:
            deep_iterable(member_validator, iterable_validator)

        e.match(r"\w* must be callable")
Example #2
    def test_noncallable_validators(self, member_validator,
                                    iterable_validator):
        """
        Raise `TypeError` if any validators are not callable.
        """
        with pytest.raises(TypeError) as e:
            deep_iterable(member_validator, iterable_validator)
        value = 42
        message = "must be callable (got {value} that is a {type_}).".format(
            value=value, type_=value.__class__)

        assert message in e.value.args[0]
        assert value == e.value.args[1]
        assert message in e.value.msg
        assert value == e.value.value
Example #3
class TocTree:
    """An individual toctree within a document."""

    # TODO validate uniqueness of docnames (at least one item)
    items: List[Union[GlobItem, FileItem, UrlItem]] = attr.ib(
        validator=deep_iterable(instance_of((GlobItem, FileItem,
                                             UrlItem)), instance_of(list)))
    caption: Optional[str] = attr.ib(None,
                                     kw_only=True,
                                     validator=optional(instance_of(str)))
    hidden: bool = attr.ib(True, kw_only=True, validator=instance_of(bool))
    maxdepth: int = attr.ib(-1, kw_only=True, validator=instance_of(int))
    numbered: Union[bool, int] = attr.ib(False,
                                         kw_only=True,
                                         validator=instance_of((bool, int)))
    reversed: bool = attr.ib(False, kw_only=True, validator=instance_of(bool))
    titlesonly: bool = attr.ib(False,
                               kw_only=True,
                               validator=instance_of(bool))

    def files(self) -> List[str]:
        return [str(item) for item in self.items if isinstance(item, FileItem)]

    def globs(self) -> List[str]:
        return [str(item) for item in self.items if isinstance(item, GlobItem)]
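The ``items`` field above pairs a member check against a tuple of types with a container check against ``list``. A standalone sketch of the same pattern, with built-in types standing in for ``GlobItem``/``FileItem``/``UrlItem``:

import attr
from attr.validators import deep_iterable, instance_of

@attr.s
class Tree:
    items = attr.ib(validator=deep_iterable(instance_of((int, str)),
                                            instance_of(list)))

Tree(items=[1, "a"])    # both checks pass
# Tree(items=(1, "a"))  # TypeError: the container must be a list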
Example #4
class MessageHeader(object):
    # pylint: disable=too-many-instance-attributes
    """Deserialized message header object.

    :param SerializationVersion version: Message format version, per spec
    :param ObjectType type: Message content type, per spec
    :param AlgorithmSuite algorithm: Algorithm to use for encryption
    :param bytes message_id: Message ID
    :param Dict[str,str] encryption_context: Dictionary defining encryption context
    :param Sequence[EncryptedDataKey] encrypted_data_keys: Encrypted data keys
    :param ContentType content_type: Message content framing type (framed/non-framed)
    :param int content_aad_length: Number of bytes of content AAD
    :param int header_iv_length: Bytes in Initialization Vector value found in header
    :param int frame_length: Length of message frame in bytes
    """

    version = attr.ib(hash=True, validator=instance_of(SerializationVersion))
    type = attr.ib(hash=True, validator=instance_of(ObjectType))
    algorithm = attr.ib(hash=True, validator=instance_of(Algorithm))
    message_id = attr.ib(hash=True, validator=instance_of(bytes))
    encryption_context = attr.ib(
        hash=True,
        validator=deep_mapping(key_validator=instance_of(six.string_types),
                               value_validator=instance_of(six.string_types)),
    )
    encrypted_data_keys = attr.ib(
        hash=True,
        validator=deep_iterable(
            member_validator=instance_of(EncryptedDataKey)))
    content_type = attr.ib(hash=True, validator=instance_of(ContentType))
    content_aad_length = attr.ib(hash=True,
                                 validator=instance_of(six.integer_types))
    header_iv_length = attr.ib(hash=True,
                               validator=instance_of(six.integer_types))
    frame_length = attr.ib(hash=True, validator=instance_of(six.integer_types))
Example #5
class DenyRegionsClientSupplier(ClientSupplier):
    """AWS KMS client supplier that supplies clients for any region except for the specified regions.

    .. versionadded:: 1.5.0

    :param List[str] denied_regions: Regions to deny
    :param ClientSupplier client_supplier: Client supplier to wrap (optional)
    """

    denied_regions = attr.ib(validator=(deep_iterable(
        member_validator=instance_of(six.string_types)),
                                        value_is_not_a_string))
    _client_supplier = attr.ib(default=attr.Factory(DefaultClientSupplier),
                               validator=optional(is_callable()))

    def __call__(self, region_name):
        # type: (Union[None, str]) -> BaseClient
        """Return a client for the requested region.

        :rtype: BaseClient
        :raises UnknownRegionError: if a region is requested that is in ``denied_regions``
        """
        if region_name in self.denied_regions:
            raise UnknownRegionError(
                "Unable to provide client for region '{}'".format(region_name))

        return self._client_supplier(region_name)
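Note that ``denied_regions`` above uses a tuple of validators; attrs applies them in order, which lets ``value_is_not_a_string`` reject a bare string that would otherwise slip through the element-wise check (a string is itself an iterable of strings). A minimal sketch with a hypothetical stand-in for that helper:

import attr
from attr.validators import deep_iterable, instance_of

def not_a_string(instance, attribute, value):
    # hypothetical stand-in for value_is_not_a_string
    if isinstance(value, str):
        raise TypeError("{} must not be a string".format(attribute.name))

@attr.s
class Denylist:
    regions = attr.ib(validator=(deep_iterable(instance_of(str)), not_a_string))

Denylist(regions=["us-east-1"])  # ok
# Denylist(regions="us-east-1")  # rejected by not_a_string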
Example #6
 def test_success_member_and_iterable(self, member_validator):
     """
     If both the member and iterable validators succeed, nothing happens.
     """
     iterable_validator = instance_of(list)
     v = deep_iterable(member_validator, iterable_validator)
     a = simple_attr("test")
     v(None, a, [42])
Example #7
 def test_fail_invalid_member(self, member_validator):
     """
     Raise member validator error if an invalid member is found.
     """
     v = deep_iterable(member_validator)
     a = simple_attr("test")
     with pytest.raises(TypeError):
         v(None, a, [42, "42"])
Example #8
 def test_success_member_only(self, member_validator):
     """
     If the member validator succeeds and the iterable validator is not set,
     nothing happens.
     """
     v = deep_iterable(member_validator)
     a = simple_attr("test")
     v(None, a, [42])
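When no ``iterable_validator`` is given, ``deep_iterable`` only iterates the value and checks its members, so any iterable container is accepted. A minimal sketch:

import attr
from attr.validators import deep_iterable, instance_of

@attr.s
class Box:
    nums = attr.ib(validator=deep_iterable(instance_of(int)))

Box(nums=[1, 2])  # a list passes
Box(nums=(1, 2))  # so does a tuple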
Example #9
class PipelineStage(_ConfigStructure):
    """CodePipeline stage definition.

    :param name: Stage name
    :param actions: Actions to be taken in stage
    """

    name: str = attr.ib(validator=instance_of(str))
    actions: Iterable[PipelineAction] = attr.ib(validator=deep_iterable(member_validator=instance_of(PipelineAction)))
Example #10
class PerceptionSemanticAlignment:
    """
    Represents an alignment between a perception graph and a set of semantic nodes representing
    concepts.

    This is used to represent intermediate semantic data passed between new-style learners when
    describing a perception.
    """

    perception_graph: PerceptionGraph = attrib(
        validator=instance_of(PerceptionGraph))
    semantic_nodes: ImmutableSet[SemanticNode] = attrib(
        converter=_to_immutableset,
        validator=deep_iterable(instance_of(SemanticNode)))
    functional_concept_to_object_concept: ImmutableDict[
        FunctionalObjectConcept, ObjectConcept] = attrib(
            converter=_to_immutabledict,
            validator=deep_mapping(instance_of(FunctionalObjectConcept),
                                   instance_of(ObjectConcept)),
            default=immutabledict(),
        )

    @staticmethod
    def create_unaligned(
            perception_graph: PerceptionGraph
    ) -> "PerceptionSemanticAlignment":
        return PerceptionSemanticAlignment(perception_graph, [])

    def copy_with_updated_graph_and_added_nodes(
            self, *, new_graph: PerceptionGraph,
            new_nodes: Iterable[SemanticNode]
    ) -> "PerceptionSemanticAlignment":
        if new_graph is self.perception_graph and not new_nodes:
            return self
        else:
            return PerceptionSemanticAlignment(
                perception_graph=new_graph,
                semantic_nodes=chain(self.semantic_nodes, new_nodes),
            )

    def copy_with_mapping(
        self, *, mapping: Mapping[FunctionalObjectConcept, ObjectConcept]
    ) -> "PerceptionSemanticAlignment":
        return PerceptionSemanticAlignment(
            perception_graph=self.perception_graph,
            semantic_nodes=self.semantic_nodes,
            functional_concept_to_object_concept=mapping,
        )

    def __attrs_post_init__(self) -> None:
        for node in self.perception_graph._graph:  # pylint:disable=protected-access
            if isinstance(node, SemanticNode):
                if node not in self.semantic_nodes:
                    raise RuntimeError(
                        "All semantic nodes appearing in the perception graph must "
                        "also be in semantic_nodes")
Example #11
 def test_fail_invalid_iterable(self):
     """
     Raise iterable validator error if an invalid iterable is found.
     """
     member_validator = instance_of(int)
     iterable_validator = instance_of(tuple)
     v = deep_iterable(member_validator, iterable_validator)
     a = simple_attr("test")
     with pytest.raises(TypeError):
         v(None, a, [42])
Example #12
 def test_fail_invalid_member_and_iterable(self, member_validator):
     """
     Raise iterable validator error if both the iterable
     and a member are invalid.
     """
     iterable_validator = instance_of(tuple)
     v = deep_iterable(member_validator, iterable_validator)
     a = simple_attr("test")
     with pytest.raises(TypeError):
         v(None, a, [42, "42"])
Example #13
class MdParserConfig:
    """Configuration options for the Markdown Parser.

    Note that in the Sphinx configuration these option names are prefixed with ``myst_``
    """

    renderer: str = attr.ib(
        default="sphinx", validator=in_(["sphinx", "html", "docutils"])
    )
    commonmark_only: bool = attr.ib(default=False, validator=instance_of(bool))
    dmath_enable: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_labels: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_space: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_digits: bool = attr.ib(default=True, validator=instance_of(bool))
    amsmath_enable: bool = attr.ib(default=False, validator=instance_of(bool))
    deflist_enable: bool = attr.ib(default=False, validator=instance_of(bool))

    update_mathjax: bool = attr.ib(default=True, validator=instance_of(bool))

    admonition_enable: bool = attr.ib(default=False, validator=instance_of(bool))
    figure_enable: bool = attr.ib(default=False, validator=instance_of(bool))

    disable_syntax: List[str] = attr.ib(
        factory=list,
        validator=deep_iterable(instance_of(str), instance_of((list, tuple))),
    )

    html_img_enable: bool = attr.ib(default=False, validator=instance_of(bool))

    # see https://en.wikipedia.org/wiki/List_of_URI_schemes
    url_schemes: Optional[List[str]] = attr.ib(
        default=None,
        validator=optional(deep_iterable(instance_of(str), instance_of((list, tuple)))),
    )

    heading_anchors: Optional[int] = attr.ib(
        default=None, validator=optional(in_([1, 2, 3, 4, 5, 6, 7]))
    )

    def as_dict(self, dict_factory=dict) -> dict:
        return attr.asdict(self, dict_factory=dict_factory)
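Wrapping a deep validator in ``optional(...)``, as ``url_schemes`` does above, additionally admits ``None``. A standalone sketch of that combination:

import attr
from attr.validators import deep_iterable, instance_of, optional

@attr.s
class Cfg:
    schemes = attr.ib(
        default=None,
        validator=optional(deep_iterable(instance_of(str),
                                         instance_of((list, tuple)))),
    )

Cfg()                         # None is allowed
Cfg(schemes=["http", "ftp"])  # otherwise validated element-wise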
Example #14
 def test_repr_member_only(self):
     """
     Returned validator has a useful `__repr__`
     when only member validator is set.
     """
     member_validator = instance_of(int)
     member_repr = "<instance_of validator for type <class 'int'>>"
     v = deep_iterable(member_validator)
     expected_repr = (
         "<deep_iterable validator for iterables of {member_repr}>").format(
             member_repr=member_repr)
     assert expected_repr == repr(v)
Example #15
class KeyringTrace(object):
    """Record of all actions that a KeyRing performed with a wrapping key.

    .. versionadded:: 1.5.0

    :param MasterKeyInfo wrapping_key: Wrapping key used
    :param Set[KeyringTraceFlag] flags: Actions performed
    """

    wrapping_key = attr.ib(validator=instance_of(MasterKeyInfo))
    flags = attr.ib(validator=deep_iterable(
        member_validator=instance_of(KeyringTraceFlag)))
Example #16
def list_of(type_: type) -> Callable:
    """
    An attr validator that validates a list and its members.

    :param type_: The type to check for; can be a type or a tuple of types

    :raises TypeError:
        if the initializer is called with a wrong type for this particular attribute

    :return: An attr validator that validates both the list itself and each of its members.
    """
    return deep_iterable(member_validator=instance_of(type_),
                         iterable_validator=instance_of(list))
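A hypothetical usage sketch, assuming ``list_of`` as defined above together with ``import attr``:

@attr.s
class Basket:
    apples = attr.ib(validator=list_of(int))

Basket(apples=[1, 2, 3])    # ok
# Basket(apples=(1, 2, 3))  # TypeError: a tuple is not a list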
Example #17
 def test_repr_member_and_iterable(self):
     """
     Returned validator has a useful `__repr__` when both member
     and iterable validators are set.
     """
     member_validator = instance_of(int)
     member_repr = "<instance_of validator for type <class 'int'>>"
     iterable_validator = instance_of(list)
     iterable_repr = "<instance_of validator for type <class 'list'>>"
     v = deep_iterable(member_validator, iterable_validator)
     expected_repr = (
         "<deep_iterable validator for"
         " {iterable_repr} iterables of {member_repr}>").format(
             iterable_repr=iterable_repr, member_repr=member_repr)
     assert expected_repr == repr(v)
Example #18
 def test_repr_member_only_sequence(self):
     """
     Returned validator has a useful `__repr__`
     when only member validator is set and the member validator is a list of
     validators
     """
     member_validator = [always_pass, instance_of(int)]
     member_repr = (
         "_AndValidator(_validators=({func}, "
         "<instance_of validator for type <class 'int'>>))").format(
             func=repr(always_pass))
     v = deep_iterable(member_validator)
     expected_repr = (
         "<deep_iterable validator for iterables of {member_repr}>").format(
             member_repr=member_repr)
     assert expected_repr == repr(v)
Example #19
class TokenSequenceLinguisticDescription(LinguisticDescription):
    """
    A `LinguisticDescription` which consists of a sequence of tokens.
    """

    tokens: Tuple[str, ...] = attrib(converter=_to_tuple,
                                     validator=deep_iterable(instance_of(str)))

    def as_token_sequence(self) -> Tuple[str, ...]:
        return self.tokens

    def __getitem__(self, item) -> str:
        return self.tokens[item]

    def __len__(self) -> int:
        return len(self.tokens)
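attrs runs converters before validators, so ``_to_tuple`` above normalizes the input first and ``deep_iterable(instance_of(str))`` then checks the converted tuple. A standalone sketch of that ordering with the plain ``tuple`` converter:

import attr
from attr.validators import deep_iterable, instance_of

@attr.s
class Tokens:
    tokens = attr.ib(converter=tuple,
                     validator=deep_iterable(instance_of(str)))

assert Tokens(tokens=["a", "b"]).tokens == ("a", "b")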
Example #20
class ByHierarchyAndProperties(OntologyNodeSelector):
    """
    An `OntologyNodeSelector` which selects all nodes
     which are descendents of *descendents_of*,
     which possess all of *required_properties*,
     and which possess none of *banned_properties*.
    """

    _descendents_of: OntologyNode = attrib(validator=instance_of(OntologyNode))
    _required_properties: ImmutableSet[OntologyNode] = attrib(
        converter=_to_immutableset, default=immutableset())
    _banned_properties: ImmutableSet[OntologyNode] = attrib(
        converter=_to_immutableset, default=immutableset())
    _banned_ontology_types: ImmutableSet[OntologyNode] = attrib(
        converter=_to_immutableset,
        default=immutableset(),
        validator=deep_iterable(instance_of(OntologyNode)),
    )

    def _select_nodes(self, ontology: Ontology) -> AbstractSet[OntologyNode]:
        return ontology.nodes_with_properties(
            self._descendents_of,
            self._required_properties,
            banned_properties=self._banned_properties,
            banned_ontology_types=self._banned_ontology_types,
        )

    def __repr__(self) -> str:
        required_properties = [
            f"+{property_}" for property_ in self._required_properties
        ]
        banned_properties = [
            f"-{property_}" for property_ in self._banned_properties
        ]
        banned_ontology_types = [
            f"-{ontology_type}"
            for ontology_type in self._banned_ontology_types
        ]

        property_string: str
        if required_properties or banned_properties or banned_ontology_types:
            property_string = f"[{', '.join(chain(required_properties, banned_properties, banned_ontology_types))}]"
        else:
            property_string = ""

        return f"ancestorIs({self._descendents_of.handle}){property_string})"
Example #21
class Document:
    """A document in the site map."""

    docname: str = attr.ib(validator=instance_of(str))
    title: Optional[str] = attr.ib(None, validator=optional(instance_of(str)))
    # TODO validate uniqueness of docnames across all parts (and none should be the docname)
    subtrees: List[TocTree] = attr.ib(factory=list,
                                      validator=deep_iterable(
                                          instance_of(TocTree),
                                          instance_of(list)))

    def child_files(self) -> List[str]:
        """Return all children files."""
        return [name for tree in self.subtrees for name in tree.files()]

    def child_globs(self) -> List[str]:
        """Return all children globs."""
        return [name for tree in self.subtrees for name in tree.globs()]
Example #22
class DecryptionMaterialsRequest(object):
    """Request object to provide to a crypto material manager's `decrypt_materials` method.

    .. versionadded:: 1.3.0

    :param algorithm: Algorithm to provide to master keys for underlying decrypt requests
    :type algorithm: aws_encryption_sdk.identifiers.Algorithm
    :param encrypted_data_keys: Set of encrypted data keys
    :type encrypted_data_keys: set of `aws_encryption_sdk.structures.EncryptedDataKey`
    :param dict encryption_context: Encryption context to provide to master keys for underlying decrypt requests
    """

    algorithm = attr.ib(validator=instance_of(Algorithm))
    encrypted_data_keys = attr.ib(validator=deep_iterable(member_validator=instance_of(EncryptedDataKey)))
    encryption_context = attr.ib(
        validator=deep_mapping(
            key_validator=instance_of(six.string_types), value_validator=instance_of(six.string_types)
        )
    )
Example #23
    def test_repr_sequence_member_and_iterable(self):
        """
        Returned validator has a useful `__repr__` when both member
        and iterable validators are set and the member validator is a list of
        validators
        """
        member_validator = [always_pass, instance_of(int)]
        member_repr = (
            "_AndValidator(_validators=({func}, "
            "<instance_of validator for type <class 'int'>>))").format(
                func=repr(always_pass))
        iterable_validator = instance_of(list)
        iterable_repr = "<instance_of validator for type <class 'list'>>"
        v = deep_iterable(member_validator, iterable_validator)
        expected_repr = (
            "<deep_iterable validator for"
            " {iterable_repr} iterables of {member_repr}>").format(
                iterable_repr=iterable_repr, member_repr=member_repr)

        assert expected_repr == repr(v)
Example #24
def attrib_instance_list(type_: type) -> attr._make._CountingAttr:
    """
    Create a new attr attribute with validator for the given type_
    All attributes created are expected to be List of the given type_

    :param validator_type:
        Inform which type(s) the List should validate.
        When not defined the param type_ will be used the type to be validated.

    """
    # Config validator
    _validator = deep_iterable(
        member_validator=instance_of(type_),
        iterable_validator=instance_of(list),
    )
    metadata = {"type": "instance_list", "class_": type_}
    return attr.ib(
        default=attr.Factory(list),
        validator=_validator,
        type=List[type_],
        metadata=metadata,
    )
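A hypothetical usage sketch, assuming the helper above and its imports (``attr``, ``typing.List``, and the attrs validators):

@attr.s
class Inventory:
    names = attrib_instance_list(str)

Inventory().names            # defaults to []
Inventory(names=["a", "b"])  # validated as a list of str
# Inventory(names=["a", 1])  # TypeError from the member validator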
Example #25
class _AwsKmsDiscoveryKeyring(Keyring):
    """AWS KMS discovery keyring that will attempt to decrypt any AWS KMS encrypted data key.

    This keyring should never be used directly.
    It should only ever be used internally by :class:`AwsKmsKeyring`.

    .. versionadded:: 1.5.0

    :param ClientSupplier client_supplier: Client supplier to use when asking for clients
    :param List[str] grant_tokens: AWS KMS grant tokens to include in requests (optional)
    """

    _client_supplier = attr.ib(validator=is_callable())
    _grant_tokens = attr.ib(
        default=attr.Factory(tuple),
        validator=(deep_iterable(member_validator=instance_of(six.string_types)), value_is_not_a_string),
    )

    def on_encrypt(self, encryption_materials):
        # type: (EncryptionMaterials) -> EncryptionMaterials
        return encryption_materials

    def on_decrypt(self, decryption_materials, encrypted_data_keys):
        # type: (DecryptionMaterials, Iterable[EncryptedDataKey]) -> DecryptionMaterials
        new_materials = decryption_materials

        for edk in encrypted_data_keys:
            if new_materials.data_encryption_key is not None:
                return new_materials

            if edk.key_provider.provider_id == KEY_NAMESPACE:
                new_materials = _try_aws_kms_decrypt(
                    client_supplier=self._client_supplier,
                    decryption_materials=new_materials,
                    grant_tokens=self._grant_tokens,
                    encrypted_data_key=edk,
                )

        return new_materials
Example #26
class CryptoResult(object):
    """Result container for one-shot cryptographic API results.

    .. versionadded:: 1.5.0

    .. note::

        For backwards compatibility,
        this container also unpacks like a 2-member tuple,
        matching the previous outputs.

    :param bytes result: Binary results of the cryptographic operation
    :param MessageHeader header: Encrypted message metadata
    :param Tuple[KeyringTrace] keyring_trace: Keyring trace entries
    """

    result = attr.ib(validator=instance_of(bytes))
    header = attr.ib(validator=instance_of(MessageHeader))
    keyring_trace = attr.ib(validator=deep_iterable(
        member_validator=instance_of(KeyringTrace)))

    def __attrs_post_init__(self):
        """Construct the inner tuple for backwards compatibility."""
        self._legacy_container = (self.result, self.header)

    def __len__(self):
        """Emulate the inner tuple."""
        return self._legacy_container.__len__()

    def __iter__(self):
        """Emulate the inner tuple."""
        return self._legacy_container.__iter__()

    def __getitem__(self, key):
        """Emulate the inner tuple."""
        return self._legacy_container.__getitem__(key)
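The ``__attrs_post_init__`` shim above is what keeps the legacy 2-tuple unpacking working. A standalone sketch of the same pattern:

import attr

@attr.s
class Result:
    value = attr.ib()
    meta = attr.ib()

    def __attrs_post_init__(self):
        self._legacy = (self.value, self.meta)

    def __iter__(self):
        return iter(self._legacy)

v, m = Result(1, "x")  # still unpacks like the old tuple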
Example #27
def attrib_instance_list(
    type_, *, validator_type: Optional[Tuple[Any, ...]] = None
) -> attr._make._CountingAttr:
    """
    Create a new attr attribute with a validator for the given ``type_``.
    Attributes created this way are expected to be lists of ``type_``.

    :param validator_type:
        The type(s) the list members should be validated against.
        When not given, ``type_`` itself is used for validation.
    """
    # Config validator
    _validator_type = validator_type or type_
    _validator = deep_iterable(
        member_validator=instance_of(_validator_type),
        iterable_validator=instance_of(list),
    )

    return attr.ib(default=attr.Factory(list),
                   validator=_validator,
                   type=List[type_])
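A hypothetical sketch of the ``validator_type`` override, assuming the helper above: the declared element type stays ``int`` while the runtime check is widened to also accept floats.

@attr.s
class Samples:
    values = attrib_instance_list(int, validator_type=(int, float))

Samples(values=[1, 2.5])  # floats pass the widened member validator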
Example #28
class MdParserConfig:
    """Configuration options for the Markdown Parser.

    Note that in the Sphinx configuration these option names are prefixed with ``myst_``
    """

    renderer: str = attr.ib(default="sphinx",
                            validator=in_(["sphinx", "html", "docutils"]))
    commonmark_only: bool = attr.ib(default=False, validator=instance_of(bool))
    dmath_allow_labels: bool = attr.ib(default=True,
                                       validator=instance_of(bool))
    dmath_allow_space: bool = attr.ib(default=True,
                                      validator=instance_of(bool))
    dmath_allow_digits: bool = attr.ib(default=True,
                                       validator=instance_of(bool))

    update_mathjax: bool = attr.ib(default=True, validator=instance_of(bool))

    # TODO remove deprecated _enable attributes after v0.13.0
    admonition_enable: bool = attr.ib(default=False,
                                      validator=instance_of(bool),
                                      repr=False)
    figure_enable: bool = attr.ib(default=False,
                                  validator=instance_of(bool),
                                  repr=False)
    dmath_enable: bool = attr.ib(default=False,
                                 validator=instance_of(bool),
                                 repr=False)
    amsmath_enable: bool = attr.ib(default=False,
                                   validator=instance_of(bool),
                                   repr=False)
    deflist_enable: bool = attr.ib(default=False,
                                   validator=instance_of(bool),
                                   repr=False)
    html_img_enable: bool = attr.ib(default=False,
                                    validator=instance_of(bool),
                                    repr=False)
    colon_fence_enable: bool = attr.ib(default=False,
                                       validator=instance_of(bool),
                                       repr=False)

    enable_extensions: Iterable[str] = attr.ib(factory=lambda: ["dollarmath"])

    @enable_extensions.validator
    def check_extensions(self, attribute, value):
        if not isinstance(value, Iterable):
            raise TypeError(f"myst_enable_extensions not iterable: {value}")
        diff = set(value).difference([
            "dollarmath",
            "amsmath",
            "deflist",
            "html_image",
            "colon_fence",
            "smartquotes",
            "replacements",
            "linkify",
            "substitution",
        ])
        if diff:
            raise ValueError(f"myst_enable_extensions not recognised: {diff}")

    disable_syntax: List[str] = attr.ib(
        factory=list,
        validator=deep_iterable(instance_of(str), instance_of((list, tuple))),
    )

    # see https://en.wikipedia.org/wiki/List_of_URI_schemes
    url_schemes: Optional[List[str]] = attr.ib(
        default=None,
        validator=optional(
            deep_iterable(instance_of(str), instance_of((list, tuple)))),
    )

    heading_anchors: Optional[int] = attr.ib(default=None,
                                             validator=optional(
                                                 in_([1, 2, 3, 4, 5, 6, 7])))

    substitutions: Dict[str, str] = attr.ib(
        factory=dict,
        validator=deep_mapping(instance_of(str), instance_of(
            (str, int, float)), instance_of(dict)),
        repr=lambda v: str(list(v)),
    )

    sub_delimiters: Tuple[str, str] = attr.ib(default=("{", "}"))

    @sub_delimiters.validator
    def check_sub_delimiters(self, attribute, value):
        if (not isinstance(value, (tuple, list))) or len(value) != 2:
            raise TypeError(
                f"myst_sub_delimiters is not a tuple of length 2: {value}")
        for delim in value:
            if (not isinstance(delim, str)) or len(delim) != 1:
                raise TypeError(
                    f"myst_sub_delimiters does not contain strings of length 1: {value}"
                )

    def as_dict(self, dict_factory=dict) -> dict:
        return attr.asdict(self, dict_factory=dict_factory)
Example #29
class Param:
    """Specification of a parameter that is to be constrained."""

    name = attr.ib()
    _min: numeric = attr.ib(-np.inf)
    _max: numeric = attr.ib(np.inf)

    prior = attr.ib(
        kw_only=True,
        validator=vld.optional(vld.instance_of(stats.distributions.rv_frozen)),
    )

    fiducial = attr.ib(
        None,
        type=float,
        converter=cnv.optional(float),
        validator=vld.optional(vld.instance_of(float)),
        kw_only=True,
    )
    latex = attr.ib(kw_only=True)
    ref = attr.ib(kw_only=True)
    determines = attr.ib(
        converter=tuplify,
        kw_only=True,
        validator=vld.deep_iterable(vld.instance_of(str)),
    )
    transforms = attr.ib(converter=tuplify, kw_only=True)

    @latex.default
    def _ltx_default(self):
        return texify(self.name)

    @ref.default
    def _ref_default(self):
        return self.prior

    @prior.default
    def _prior_default(self) -> stats.distributions.rv_frozen | None:
        if np.isinf(self._min) or np.isinf(self._max):
            return None

        return stats.uniform(self._min, self._max - self._min)

    @determines.default
    def _determines_default(self):
        return (self.name,)

    @transforms.default
    def _transforms_default(self):
        return (None,) * len(self.determines)

    @transforms.validator
    def _transforms_validator(self, attribute, value):
        for val in value:
            if val is not None and not callable(val):
                raise TypeError("transforms must be a list of callables")

    @property
    def min(self) -> float:
        """The minimum boundary of the prior, helpful for constraints."""
        if self.prior is None:
            return self._min
        elif isinstance(self.prior, type(stats.uniform(0, 1))):
            return self.prior.support()[0]
        else:
            return -np.inf

    @property
    def max(self) -> float:
        """The maximum boundary of the prior, helpful for constraints."""
        if self.prior is None:
            return self._max
        elif isinstance(self.prior, type(stats.uniform(0, 1))):
            return self.prior.support()[1]
        else:
            return np.inf

    @cached_property
    def is_alias(self):
        return all(pm is None for pm in self.transforms)

    @cached_property
    def is_pure_alias(self):
        return self.is_alias and len(self.determines) == 1

    def transform(self, val):
        for pm in self.transforms:
            if pm is None:
                yield val
            else:
                yield pm(val)

    def generate_ref(self, n=1):
        if self.ref is None:
            raise ValueError("Must specify a valid function for ref to generate refs.")

        try:
            ref = self.ref.rvs(size=n)
        except AttributeError:
            try:
                ref = self.ref(size=n)
            except TypeError:
                raise TypeError(
                    f"parameter '{self.name}' does not have a valid value for ref"
                )

        if np.any(self.prior.pdf(ref) == 0):
            raise ValueError(
                f"param {self.name} produced a reference value outside its domain."
            )

        return ref

    def logprior(self, val):
        if self.prior is None:
            if self._min > val or self._max < val:
                return -np.inf
            else:
                return 0

        return self.prior.logpdf(val)

    def clone(self, **kwargs):
        return attr.evolve(self, **kwargs)

    def new(self, p: Parameter) -> Param:
        """Create a new :class:`Param`.

        Any info missing from this instance is filled in from the given instance.
        """
        assert isinstance(p, Parameter)
        assert self.determines == (p.name,)

        if len(self.determines) > 1:
            raise ValueError("Cannot create new Param if it is not just an alias")

        default_range = (list(self.transform(p.min))[0], list(self.transform(p.max))[0])

        return Param(
            name=self.name,
            min=max(self._min, min(default_range)),
            max=min(self._max, max(default_range)),
            fiducial=self.fiducial if self.fiducial is not None else p.fiducial,
            latex=self.latex
            if (self.latex != self.name or self.name != p.name)
            else p.latex,
            ref=self.ref or attr.NOTHING,
            prior=self.prior or attr.NOTHING,
            determines=self.determines,
            transforms=self.transforms,
        )

    def __getstate__(self):
        """Obtain a simple input state of the class that can initialize it."""
        out = attr.asdict(self)

        if self.transforms == (None,):
            del out["transforms"]
        if self.ref is None:
            del out["ref"]
        if self.determines == (self.name,):
            del out["determines"]
        if self.latex == self.name:
            del out["latex"]

        return out

    def as_dict(self):
        """Simple representation of the class as a dict.

        No "name" is included in the dict.
        """
        out = self.__getstate__()
        del out["name"]
        return out
Example #30
class ObjectRecognizer:
    """
    The ObjectRecognizer finds object matches in the scene pattern and adds an
    `ObjectSemanticNodePerceptionPredicate` which can be used to learn additional
    semantics relating objects to other objects.

    If applied to a dynamic situation, this will only recognize objects
    which are present in both the BEFORE and AFTER frames.
    """

    # Because static patterns must be applied to static perceptions
    # and dynamic patterns to dynamic situations,
    # we need to store our patterns both ways.
    _concepts_to_static_patterns: ImmutableDict[
        ObjectConcept, PerceptionGraphPattern] = attrib(
            validator=deep_mapping(instance_of(ObjectConcept),
                                   instance_of(PerceptionGraphPattern)),
            converter=_to_immutabledict,
        )
    _concepts_to_names: ImmutableDict[ObjectConcept, str] = attrib(
        validator=deep_mapping(instance_of(ObjectConcept), instance_of(str)),
        converter=_to_immutabledict,
    )

    # We derive these from the static patterns.
    _concepts_to_dynamic_patterns: ImmutableDict[
        ObjectConcept, PerceptionGraphPattern] = attrib(init=False)
    determiners: ImmutableSet[str] = attrib(converter=_to_immutableset,
                                            validator=deep_iterable(
                                                instance_of(str)))
    """
    This is a hack to handle determiners.
    See https://github.com/isi-vista/adam/issues/498
    """
    _concept_to_num_subobjects: ImmutableDict[Concept,
                                              int] = attrib(init=False)
    """
    Used for a performance optimization in match_objects.
    """
    _language_mode: LanguageMode = attrib(validator=instance_of(LanguageMode),
                                          kw_only=True)

    def __attrs_post_init__(self) -> None:
        non_lowercase_determiners = [
            determiner for determiner in self.determiners
            if determiner.lower() != determiner
        ]
        if non_lowercase_determiners:
            raise RuntimeError(
                f"All determiners must be specified in lowercase, but got "
                f"{non_lowercase_determiners}")

    @staticmethod
    def for_ontology_types(
        ontology_types: Iterable[OntologyNode],
        determiners: Iterable[str],
        ontology: Ontology,
        language_mode: LanguageMode,
        *,
        perception_generator:
        HighLevelSemanticsSituationToDevelopmentalPrimitivePerceptionGenerator,
    ) -> "ObjectRecognizer":
        ontology_types_to_concepts = {
            obj_type: ObjectConcept(obj_type.handle)
            for obj_type in ontology_types
        }

        return ObjectRecognizer(
            concepts_to_static_patterns=_sort_mapping_by_pattern_complexity(
                immutabledict((
                    concept,
                    PerceptionGraphPattern.from_ontology_node(
                        obj_type,
                        ontology,
                        perception_generator=perception_generator),
                ) for (obj_type,
                       concept) in ontology_types_to_concepts.items())),
            determiners=determiners,
            concepts_to_names={
                concept: obj_type.handle
                for obj_type, concept in ontology_types_to_concepts.items()
            },
            language_mode=language_mode,
        )

    def match_objects_old(
        self, perception_graph: PerceptionGraph
    ) -> PerceptionGraphFromObjectRecognizer:
        new_style_input = PerceptionSemanticAlignment(
            perception_graph=perception_graph, semantic_nodes=[])
        new_style_output = self.match_objects(new_style_input)
        return PerceptionGraphFromObjectRecognizer(
            perception_graph=new_style_output[0].perception_graph,
            description_to_matched_object_node=new_style_output[1],
        )

    def match_objects(
        self,
        perception_semantic_alignment: PerceptionSemanticAlignment,
        *,
        post_process: Callable[[PerceptionGraph, AbstractSet[SemanticNode]],
                               Tuple[PerceptionGraph,
                                     AbstractSet[SemanticNode]],
                               ] = default_post_process_enrichment,
    ) -> Tuple[PerceptionSemanticAlignment, Mapping[Tuple[str, ...],
                                                    ObjectSemanticNode]]:
        r"""
        Recognize known objects in a `PerceptionGraph`.

        The matched portions of the graph will be replaced with `ObjectSemanticNode`\ s,
        which inherit all relationships between nodes internal to the matched portion
        and any external nodes.

        This is useful as a pre-processing step
        before prepositional and verbal learning experiments.
        """

        # pylint: disable=global-statement,invalid-name
        global cumulative_millis_in_successful_matches_ms
        global cumulative_millis_in_failed_matches_ms

        object_nodes: List[Tuple[Tuple[str, ...], ObjectSemanticNode]] = []
        perception_graph = perception_semantic_alignment.perception_graph
        is_dynamic = perception_semantic_alignment.perception_graph.dynamic

        if is_dynamic:
            concepts_to_patterns = self._concepts_to_dynamic_patterns
        else:
            concepts_to_patterns = self._concepts_to_static_patterns

        # We special-case handling of the ground perception
        # because we don't want to remove it from the graph; we just want to use its
        # object node as a recognized object. The situation "a box on the ground"
        # prompted the need to recognize the ground.
        graph_to_return = perception_graph
        for node in graph_to_return._graph.nodes:  # pylint:disable=protected-access
            if node == GROUND_PERCEPTION:
                matched_object_node = ObjectSemanticNode(GROUND_OBJECT_CONCEPT)
                if LanguageMode.ENGLISH == self._language_mode:
                    object_nodes.append(
                        ((f"{GROUND_OBJECT_CONCEPT.debug_string}", ),
                         matched_object_node))
                elif LanguageMode.CHINESE == self._language_mode:
                    object_nodes.append((("di4 myan4", ), matched_object_node))
                else:
                    raise RuntimeError("Invalid language_generator")
                # We construct a fake match which is only the ground perception node
                subgraph_of_root = subgraph(perception_graph.copy_as_digraph(),
                                            [node])
                pattern_match = PerceptionGraphPatternMatch(
                    matched_pattern=PerceptionGraphPattern(
                        graph=subgraph_of_root,
                        dynamic=perception_graph.dynamic),
                    graph_matched_against=perception_graph,
                    matched_sub_graph=PerceptionGraph(
                        graph=subgraph_of_root,
                        dynamic=perception_graph.dynamic),
                    pattern_node_to_matched_graph_node=immutabledict(),
                )
                graph_to_return = replace_match_with_object_graph_node(
                    matched_object_node, graph_to_return, pattern_match)

        candidate_object_subgraphs = extract_candidate_objects(
            perception_graph)

        for candidate_object_graph in candidate_object_subgraphs:
            num_object_nodes = candidate_object_graph.count_nodes_matching(
                lambda node: isinstance(node, ObjectPerception))

            for (concept, pattern) in concepts_to_patterns.items():
                # As an optimization, we count how many sub-object nodes
                # are in the graph and the pattern.
                # If they aren't the same, the match is impossible
                # and we can bail out early.
                if num_object_nodes != self._concept_to_num_subobjects[concept]:
                    continue

                with Timer(factor=1000) as t:
                    matcher = pattern.matcher(candidate_object_graph,
                                              match_mode=MatchMode.OBJECT)
                    pattern_match = first(
                        matcher.matches(use_lookahead_pruning=True), None)
                if pattern_match:
                    cumulative_millis_in_successful_matches_ms += t.elapsed
                    matched_object_node = ObjectSemanticNode(concept)

                    # We wrap the concept in a tuple because it could in theory be multiple
                    # tokens,
                    # even though currently it never is.
                    if self._language_mode == LanguageMode.ENGLISH:
                        object_nodes.append(
                            ((concept.debug_string, ), matched_object_node))
                    elif self._language_mode == LanguageMode.CHINESE:
                        if concept.debug_string == "me":
                            object_nodes.append(
                                (("wo3", ), matched_object_node))
                        elif concept.debug_string == "you":
                            object_nodes.append(
                                (("ni3", ), matched_object_node))
                        mappings = (
                            GAILA_PHASE_1_CHINESE_LEXICON.
                            _ontology_node_to_word  # pylint:disable=protected-access
                        )
                        for k, v in mappings.items():
                            if k.handle == concept.debug_string:
                                debug_string = str(v.base_form)
                                object_nodes.append(
                                    ((debug_string, ), matched_object_node))
                    graph_to_return = replace_match_with_object_graph_node(
                        matched_object_node, graph_to_return, pattern_match)
                    # We match each candidate object against only one object type.
                    # See https://github.com/isi-vista/adam/issues/627
                    break
                else:
                    cumulative_millis_in_failed_matches_ms += t.elapsed
        if object_nodes:
            logging.info(
                "Object recognizer recognized: %s",
                [concept for (concept, _) in object_nodes],
            )
        logging.info(
            "object matching: ms in success: %s, ms in failed: %s",
            cumulative_millis_in_successful_matches_ms,
            cumulative_millis_in_failed_matches_ms,
        )
        semantic_object_nodes = immutableset(node
                                             for (_, node) in object_nodes)

        post_process_graph, post_process_nodes = post_process(
            graph_to_return, semantic_object_nodes)

        return (
            perception_semantic_alignment.
            copy_with_updated_graph_and_added_nodes(
                new_graph=post_process_graph, new_nodes=post_process_nodes),
            immutabledict(object_nodes),
        )

    def match_objects_with_language_old(
        self, language_aligned_perception: LanguageAlignedPerception
    ) -> LanguageAlignedPerception:
        if language_aligned_perception.node_to_language_span:
            raise RuntimeError(
                "Don't know how to handle a non-empty node-to-language-span")
        new_style_input = LanguagePerceptionSemanticAlignment(
            language_concept_alignment=LanguageConceptAlignment(
                language_aligned_perception.language,
                node_to_language_span=[]),
            perception_semantic_alignment=PerceptionSemanticAlignment(
                perception_graph=language_aligned_perception.perception_graph,
                semantic_nodes=[],
            ),
        )
        new_style_output = self.match_objects_with_language(new_style_input)
        return LanguageAlignedPerception(
            language=new_style_output.language_concept_alignment.language,
            perception_graph=new_style_output.perception_semantic_alignment.
            perception_graph,
            node_to_language_span=new_style_output.language_concept_alignment.
            node_to_language_span,
        )

    def match_objects_with_language(
        self,
        language_perception_semantic_alignment:
        LanguagePerceptionSemanticAlignment,
        *,
        post_process: Callable[[PerceptionGraph, AbstractSet[SemanticNode]],
                               Tuple[PerceptionGraph,
                                     AbstractSet[SemanticNode]],
                               ] = default_post_process_enrichment,
    ) -> LanguagePerceptionSemanticAlignment:
        """
        Recognize known objects in a `LanguagePerceptionSemanticAlignment`.

        For each node matched, this will identify the relevant portion of the linguistic input
        and record the correspondence.

        The matched portion of the graph will be replaced with an `ObjectSemanticNode`
        which inherits all relationships between nodes internal to the matched portion
        and any external nodes.

        This is useful as a pre-processing step
        before prepositional and verbal learning experiments.
        """
        if (language_perception_semantic_alignment.
                perception_semantic_alignment.semantic_nodes):
            raise RuntimeError(
                "We assume ObjectRecognizer is run first, with no previous "
                "alignments")

        (
            post_match_perception_semantic_alignment,
            tokens_to_object_nodes,
        ) = self.match_objects(
            language_perception_semantic_alignment.
            perception_semantic_alignment,
            post_process=post_process,
        )
        return LanguagePerceptionSemanticAlignment(
            language_concept_alignment=language_perception_semantic_alignment.
            language_concept_alignment.copy_with_added_token_alignments(
                self._align_objects_to_tokens(
                    tokens_to_object_nodes,
                    language_perception_semantic_alignment.
                    language_concept_alignment.language,
                )),
            perception_semantic_alignment=
            post_match_perception_semantic_alignment,
        )

    def _align_objects_to_tokens(
        self,
        description_to_object_node: Mapping[Tuple[str, ...],
                                            ObjectSemanticNode],
        language: LinguisticDescription,
    ) -> Mapping[ObjectSemanticNode, Span]:
        result: List[Tuple[ObjectSemanticNode, Span]] = []

        # We want to ban the same token index from being aligned twice.
        matched_token_indices: Set[int] = set()

        for (description_tuple,
             object_node) in description_to_object_node.items():
            if len(description_tuple) != 1:
                raise RuntimeError(
                    f"Multi-token descriptions are not yet supported:"
                    f"{description_tuple}")
            description = description_tuple[0]
            try:
                end_index_inclusive = language.index(description)
            except ValueError:
                # A scene might contain things which are not referred to by the associated language.
                continue

            start_index = end_index_inclusive
            # This is a somewhat language-dependent hack to gobble up preceding determiners.
            # See https://github.com/isi-vista/adam/issues/498 .
            if end_index_inclusive > 0:
                possible_determiner_index = end_index_inclusive - 1
                if language[possible_determiner_index].lower(
                ) in self.determiners:
                    start_index = possible_determiner_index

            # We record what tokens were covered so we can block the same tokens being used twice.
            for included_token_index in range(start_index,
                                              end_index_inclusive + 1):
                if included_token_index in matched_token_indices:
                    raise RuntimeError(
                        "We do not currently support the same object "
                        "being mentioned twice in a sentence.")
                matched_token_indices.add(included_token_index)

            result.append((
                object_node,
                language.span(start_index,
                              end_index_exclusive=end_index_inclusive + 1),
            ))
        return immutabledict(result)

    @_concepts_to_dynamic_patterns.default
    def _init_concepts_to_dynamic_patterns(
            self) -> ImmutableDict[ObjectConcept, PerceptionGraphPattern]:
        return immutabledict(
            (concept, static_pattern.copy_with_temporal_scopes(ENTIRE_SCENE))
            for (concept,
                 static_pattern) in self._concepts_to_static_patterns.items())

    @_concept_to_num_subobjects.default
    def _init_patterns_to_num_subobjects(
            self) -> ImmutableDict[ObjectConcept, int]:
        return immutabledict((
            concept,
            pattern.count_nodes_matching(
                lambda node: isinstance(node, AnyObjectPerception)),
        ) for (concept, pattern) in self._concepts_to_static_patterns.items())