Ejemplo n.º 1
0
class Config(_ConfigStructure):
    """PipeFormer project configuration.

    :param name: Project name
    :param description: Project description
    :param generate_cmk: Should a custom CMK be generated? (reserved for later use: must always be ``True``)
    :param pipeline: Mapping of stage names to pipeline stages
    :param inputs: Mapping of input names to loaded inputs
    """

    name: str = attr.ib(validator=instance_of(str))
    description: str = attr.ib(validator=instance_of(str))
    generate_cmk: bool = attr.ib(validator=instance_of(bool))
    pipeline: Dict[str, PipelineStage] = attr.ib(
        validator=deep_mapping(key_validator=instance_of(str), value_validator=instance_of(PipelineStage))
    )
    inputs: Dict[str, Input] = attr.ib(
        validator=optional(deep_mapping(key_validator=instance_of(str), value_validator=instance_of(Input)))
    )

    @generate_cmk.validator
    def _check_generate_cmk(self, attribute, value):  # pylint: disable=unused-argument,no-self-use
        """Validate that the ``generate_cmk`` value is always ``True``.

        :raises ValueError: if ``generate_cmk`` is falsy
        """
        if not value:
            raise ValueError(
                "Use of AWS-managed CMKs is not supported. Must use customer-managed CMK (generate-cmk: true)."
            )

    @classmethod
    def from_dict(cls, kwargs: Dict):
        """Load a PipeFormer config from a dictionary parsed from a PipeFormer config file.

        :param kwargs: Parsed config file dictionary
        :return: Loaded PipeFormer config
        """
        loaded = kwargs.copy()

        # "inputs" is optional: expand each raw mapping into an Input,
        # injecting the mapping key as the input's name.
        if "inputs" in loaded:
            loaded["inputs"] = {
                key: Input.from_dict(dict(name=key, **value)) for key, value in kwargs["inputs"].items()
            }

        # Each pipeline stage is a list of raw action mappings keyed by stage name.
        loaded["pipeline"] = {
            key: PipelineStage(name=key, actions=[PipelineAction.from_dict(value) for value in actions])
            for key, actions in kwargs["pipeline"].items()
        }

        return cls(**cls._clean_kwargs(loaded))

    @classmethod
    def from_file(cls, filename: str):
        """Load a PipeFormer config from an existing file.

        :param filename: Existing filename
        :return: Loaded PipeFormer config
        """
        with open(filename, "rb") as config_file:
            raw_parsed = yaml.safe_load(config_file)

        return cls.from_dict(raw_parsed)
Ejemplo n.º 2
0
    def test_noncallable_validators(self, key_validator, value_validator, mapping_validator):
        """
        A non-callable key, value, or mapping validator triggers `TypeError`.
        """
        with pytest.raises(TypeError) as exc_info:
            deep_mapping(key_validator, value_validator, mapping_validator)
        exc_info.match(r"\w* must be callable")
Ejemplo n.º 3
0
class InventoryModel:
    """Inventory pricing model: goods with unit prices plus discount rules.

    ``discounts`` and ``multibuy`` may only reference item names that already
    exist in ``goods``; this is enforced in ``__attrs_post_init__``.
    """

    # Item name -> unit price. Keys must be non-empty strings.
    # NOTE(review): value_validator combines instance_of(float) with
    # positive_int — applying an int check to a float looks off; confirm
    # what positive_int actually validates.
    goods: Dict[str, float] = attrib(validator=validators.deep_mapping(
        key_validator=validators.and_(validators.instance_of(str),
                                      non_empty_str),
        value_validator=validators.and_(
            validators.instance_of(float),
            positive_int  # type: ignore
        ),
        mapping_validator=validators.instance_of(dict),
    ))
    # Item name -> discount rule; missing items fall back to NO_DISCOUNT
    # via the defaultdict produced by the converter.
    discounts: Dict[str, DiscountModel] = attrib(
        validator=validators.deep_mapping(
            key_validator=validators.instance_of(str),
            value_validator=validators.instance_of(DiscountModel),
            mapping_validator=validators.instance_of(dict),
        ),
        converter=defaultdict_of_classes(  # type: ignore
            DiscountModel, lambda: NO_DISCOUNT),
    )
    # Item name -> multibuy rule; missing items fall back to NO_MULTIBUY.
    multibuy: Dict[str, MultibuyModel] = attrib(
        validator=validators.deep_mapping(
            key_validator=validators.instance_of(str),
            value_validator=validators.instance_of(MultibuyModel),
            mapping_validator=validators.instance_of(dict),
        ),
        converter=defaultdict_of_classes(  # type: ignore
            MultibuyModel, lambda: NO_MULTIBUY),
    )

    def __attrs_post_init__(self):
        """Cross-field validation: every referenced item must be a known good.

        :raises ValueError: if a discount or multibuy rule references an
            unknown item name
        """
        goods = set(self.goods.keys())

        # checks if not discounting unknown items
        # NOTE(review): loop variable "attr" shadows any module-level
        # "attr"/attrs import in this file.
        for attr in ["discounts", "multibuy"]:
            unknown = set(getattr(self, attr).keys()).difference(goods)
            if len(unknown) > 0:
                raise ValueError(
                    f"Unknown goods as key in '{attr}': {unknown}")

        # NOTE(review): the attribute read here is "discounts_goods" while the
        # error message says "discounted_goods" — presumably the same field;
        # confirm against MultibuyModel's definition.
        unknown = {
            multibuy_item.discounts_goods
            for multibuy_item in self.multibuy.values()
        }.difference(goods)

        if len(unknown) > 0:
            raise ValueError(
                f"Unknown goods as value for 'multibuy.*.discounted_goods': {unknown}"
            )
Ejemplo n.º 4
0
    def test_noncallable_validators(self, key_validator, value_validator,
                                    mapping_validator):
        """
        Raise `TypeError` if any validators are not callable.
        """
        with pytest.raises(TypeError) as e:
            deep_mapping(key_validator, value_validator, mapping_validator)

        # NOTE(review): assumes the fixture supplies 42 as the non-callable
        # validator — confirm against the test parametrization.
        value = 42
        message = "must be callable (got {value} that is a {type_}).".format(
            value=value, type_=value.__class__)

        # .msg and .value are attrs' NotCallableError attributes, not
        # standard TypeError fields; args[0]/args[1] mirror them.
        assert message in e.value.args[0]
        assert value == e.value.args[1]
        assert message in e.value.msg
        assert value == e.value.value
class EncryptionMaterialsRequest(object):
    """Request object to provide to a crypto material manager's `get_encryption_materials` method.

    .. versionadded:: 1.3.0

    .. warning::
        If plaintext_rostream seek position is modified, it must be returned before leaving method.

    :param dict encryption_context: Encryption context passed to underlying master key provider and master keys
    :param int frame_length: Frame length to be used while encrypting stream
    :param plaintext_rostream: Source plaintext read-only stream (optional)
    :type plaintext_rostream: aws_encryption_sdk.internal.utils.streams.ROStream
    :param algorithm: Algorithm passed to underlying master key provider and master keys (optional)
    :type algorithm: aws_encryption_sdk.identifiers.Algorithm
    :param int plaintext_length: Length of source plaintext (optional)
    """

    # Keys and values must both be strings (six.string_types covers py2/py3).
    encryption_context = attr.ib(
        validator=deep_mapping(
            key_validator=instance_of(six.string_types), value_validator=instance_of(six.string_types)
        )
    )
    frame_length = attr.ib(validator=instance_of(six.integer_types))
    # The remaining attributes are optional and default to None.
    plaintext_rostream = attr.ib(default=None, validator=optional(instance_of(ROStream)))
    algorithm = attr.ib(default=None, validator=optional(instance_of(Algorithm)))
    plaintext_length = attr.ib(default=None, validator=optional(instance_of(six.integer_types)))
Ejemplo n.º 6
0
class JobDescription:
    """Description of a job submitted to the job manager."""

    # The job driver language, this field determines how to start the
    # driver. The value is one of the names of enum Language defined in
    # common.proto, e.g. PYTHON
    language = attr.ib(type=str, validator=in_(common_pb2.Language.keys()))
    # The runtime_env (RuntimeEnvDict) for the job config.
    runtime_env = attr.ib(type=RuntimeEnv,
                          converter=lambda kw: RuntimeEnv(**kw))
    # The entry to start the driver.
    # PYTHON:
    #   - The basename of driver filename without extension in the job
    #   package archive.
    # JAVA:
    #   - The driver class full name in the job package archive.
    driver_entry = attr.ib(type=str, validator=instance_of(str))
    # The driver arguments in list.
    # PYTHON:
    #   -  The arguments to pass to the main() function in driver entry.
    #   e.g. [1, False, 3.14, "abc"]
    # JAVA:
    #   - The arguments to pass to the driver command line.
    #   e.g. ["-custom-arg", "abc"]
    # attr.Factory gives each instance its own list: a bare default=[] is
    # shared across all instances in attrs, so one job's mutations would
    # leak into every other job.
    driver_args = attr.ib(type=list, validator=instance_of(list),
                          default=attr.Factory(list))
    # The environment vars to pass to job config, type of keys should be str.
    # attr.Factory(dict) for the same shared-mutable-default reason as above.
    env = attr.ib(type=dict,
                  validator=deep_mapping(key_validator=instance_of(str),
                                         value_validator=any_(),
                                         mapping_validator=instance_of(dict)),
                  default=attr.Factory(dict))
class MessageHeader(object):
    # pylint: disable=too-many-instance-attributes
    """Deserialized message header object.

    :param SerializationVersion version: Message format version, per spec
    :param ObjectType type: Message content type, per spec
    :param AlgorithmSuite algorithm: Algorithm to use for encryption
    :param bytes message_id: Message ID
    :param Dict[str,str] encryption_context: Dictionary defining encryption context
    :param Sequence[EncryptedDataKey] encrypted_data_keys: Encrypted data keys
    :param ContentType content_type: Message content framing type (framed/non-framed)
    :param int content_aad_length: empty
    :param int header_iv_length: Bytes in Initialization Vector value found in header
    :param int frame_length: Length of message frame in bytes
    """

    # hash=True on every field: all fields participate in the instance hash.
    version = attr.ib(hash=True, validator=instance_of(SerializationVersion))
    type = attr.ib(hash=True, validator=instance_of(ObjectType))
    algorithm = attr.ib(hash=True, validator=instance_of(Algorithm))
    message_id = attr.ib(hash=True, validator=instance_of(bytes))
    # Keys and values must both be strings (six.string_types covers py2/py3).
    encryption_context = attr.ib(
        hash=True,
        validator=deep_mapping(key_validator=instance_of(six.string_types),
                               value_validator=instance_of(six.string_types)),
    )
    encrypted_data_keys = attr.ib(
        hash=True,
        validator=deep_iterable(
            member_validator=instance_of(EncryptedDataKey)))
    content_type = attr.ib(hash=True, validator=instance_of(ContentType))
    content_aad_length = attr.ib(hash=True,
                                 validator=instance_of(six.integer_types))
    header_iv_length = attr.ib(hash=True,
                               validator=instance_of(six.integer_types))
    frame_length = attr.ib(hash=True, validator=instance_of(six.integer_types))
Ejemplo n.º 8
0
class ActionSemanticNode(SemanticNode):
    """Semantic node pairing an action concept with its slot fillers."""

    # The action concept this node instantiates.
    concept: ActionConcept = attrib(validator=instance_of(ActionConcept))
    # Slot variable -> object node filling that slot; converted to an
    # immutable dict on construction.
    slot_fillings: ImmutableDict[SyntaxSemanticsVariable,
                                 "ObjectSemanticNode"] = attrib(
                                     converter=_to_immutabledict,
                                     validator=deep_mapping(
                                         instance_of(SyntaxSemanticsVariable),
                                         instance_of(ObjectSemanticNode)),
                                 )
Ejemplo n.º 9
0
 def test_success(self):
     """
     Nothing is raised when every key and every value validates.
     """
     validator = deep_mapping(instance_of(str), instance_of(int))
     attribute = simple_attr("test")
     validator(None, attribute, {"a": 6, "b": 7})
Ejemplo n.º 10
0
class PerceptionSemanticAlignment:
    """
    Represents an alignment between a perception graph and a set of semantic nodes representing
    concepts.

    This is used to represent intermediate semantic data passed between new-style learners when
    describing a perception.
    """

    # The perception graph being described.
    perception_graph: PerceptionGraph = attrib(
        validator=instance_of(PerceptionGraph))
    # Semantic nodes aligned to the graph; converted to an immutable set.
    semantic_nodes: ImmutableSet[SemanticNode] = attrib(
        converter=_to_immutableset,
        validator=deep_iterable(instance_of(SemanticNode)))
    # Optional mapping resolving functional object concepts to concrete
    # object concepts; defaults to empty.
    functional_concept_to_object_concept: ImmutableDict[
        FunctionalObjectConcept, ObjectConcept] = attrib(
            converter=_to_immutabledict,
            validator=deep_mapping(instance_of(FunctionalObjectConcept),
                                   instance_of(ObjectConcept)),
            default=immutabledict(),
        )

    @staticmethod
    def create_unaligned(
            perception_graph: PerceptionGraph
    ) -> "PerceptionSemanticAlignment":
        """Build an alignment for *perception_graph* with no semantic nodes."""
        return PerceptionSemanticAlignment(perception_graph, [])

    def copy_with_updated_graph_and_added_nodes(
            self, *, new_graph: PerceptionGraph,
            new_nodes: Iterable[SemanticNode]
    ) -> "PerceptionSemanticAlignment":
        """Return a copy with *new_graph* and *new_nodes* appended.

        Returns ``self`` unchanged when neither the graph nor the node set
        would change.
        """
        if new_graph is self.perception_graph and not new_nodes:
            return self
        else:
            return PerceptionSemanticAlignment(
                perception_graph=new_graph,
                semantic_nodes=chain(self.semantic_nodes, new_nodes),
            )

    def copy_with_mapping(
        self, *, mapping: Mapping[FunctionalObjectConcept, ObjectConcept]
    ) -> "PerceptionSemanticAlignment":
        """Return a copy whose functional-concept mapping is *mapping*."""
        return PerceptionSemanticAlignment(
            perception_graph=self.perception_graph,
            semantic_nodes=self.semantic_nodes,
            functional_concept_to_object_concept=mapping,
        )

    def __attrs_post_init__(self) -> None:
        """Check that semantic nodes in the graph also appear in semantic_nodes."""
        for node in self.perception_graph._graph:  # pylint:disable=protected-access
            if isinstance(node, SemanticNode):
                if node not in self.semantic_nodes:
                    raise RuntimeError(
                        "All semantic nodes appearing in the perception graph must "
                        "also be in semantic_nodes")
Ejemplo n.º 11
0
 def test_fail_invalid_member(self):
     """
     Raise value validator error if an invalid member value is found.
     """
     key_validator = instance_of(str)
     value_validator = instance_of(int)
     v = deep_mapping(key_validator, value_validator)
     a = simple_attr("test")
     # "a" maps to the string "6": keys pass, but the int value validator fails.
     with pytest.raises(TypeError):
         v(None, a, {"a": "6", "b": 7})
Ejemplo n.º 12
0
 def test_fail_invalid_mapping(self):
     """
     `TypeError` is raised when the mapping validator rejects the value.
     """
     validator = deep_mapping(
         instance_of(str), instance_of(int), instance_of(dict)
     )
     attribute = simple_attr("test")
     with pytest.raises(TypeError):
         validator(None, attribute, None)
Ejemplo n.º 13
0
 def test_repr(self):
     """
     Returned validator has a useful `__repr__`.
     """
     validator = deep_mapping(instance_of(str), instance_of(int))
     expected_repr = (
         "<deep_mapping validator for objects mapping "
         "<instance_of validator for type <class 'str'>> to "
         "<instance_of validator for type <class 'int'>>>"
     )
     assert repr(validator) == expected_repr
Ejemplo n.º 14
0
class DeploymentFile:
    """Collection of per-region deployments.

    ``Deployments`` maps region names to ``Deployment`` records; the default
    is a ``defaultdict`` that constructs a ``Deployment`` on first access to
    an unknown region.
    """

    Deployments = attr.ib(
        default=attr.Factory(partial(defaultdict, Deployment)),
        validator=deep_mapping(key_validator=instance_of(str),
                               value_validator=instance_of(Deployment)),
    )

    @classmethod
    def from_dict(cls, kwargs):
        """Build a DeploymentFile from a parsed dictionary.

        NOTE(review): keys other than "Deployments" are silently ignored, and
        the constructed mapping is a plain dict rather than the defaultdict
        default — presumably intentional; confirm against callers.
        """
        return cls(
            Deployments={
                region: Deployment(**sub_args)
                for region, sub_args in kwargs.get("Deployments", {}).items()
            })
Ejemplo n.º 15
0
class Pipeline:
    """Container to hold all templates for a single PipeFormer pipeline.

    :param template: CodePipeline stack template
    :param stage_templates: Mapping of stage names to corresponding CodeBuild templates
    """

    template: Template = attr.ib(validator=instance_of(Template))
    # The mapping itself must be an OrderedDict (enforced by mapping_validator).
    stage_templates: Dict[str, Template] = attr.ib(
        validator=deep_mapping(
            key_validator=instance_of(str),
            value_validator=instance_of(Template),
            mapping_validator=instance_of(OrderedDict),
        )
    )
Ejemplo n.º 16
0
def dict_of(type_: type) -> Callable:
    """
    An attr validator that performs validation of dictionary values.

    :param type_: The type to check for, can be a type or tuple of types

    :raises TypeError:
        raises a `TypeError` if the initializer is called with a wrong type for this particular attribute

    :return: An attr validator that performs validation of dictionary values.
    """
    # Keys are always strings; values must match the requested type.
    str_keys = instance_of(str)
    typed_values = instance_of(type_)
    return deep_mapping(
        key_validator=str_keys,
        value_validator=typed_values,
        mapping_validator=instance_of(dict),
    )
Ejemplo n.º 17
0
class SurfaceTemplateBoundToSemanticNodes:
    """
    A surface template together with a mapping from its slots to particular semantic roles.

    This is used to specify what the thing we are trying to learn the meaning of in
    a template learner is.  For example, "what does 'X eats Y' mean, given that
    we know X is this thing and Y is that other thing in this particular situation.
    """

    # The template whose meaning is being learned.
    surface_template: SurfaceTemplate = attrib(
        validator=instance_of(SurfaceTemplate))
    # Template slot variable -> object node bound to that slot; converted to
    # an immutable dict on construction.
    slot_to_semantic_node: ImmutableDict[
        SyntaxSemanticsVariable, ObjectSemanticNode] = attrib(
            converter=_to_immutabledict,
            validator=deep_mapping(instance_of(SyntaxSemanticsVariable),
                                   instance_of(ObjectSemanticNode)),
        )
Ejemplo n.º 18
0
class DecryptionMaterialsRequest(object):
    """Request object to provide to a crypto material manager's `decrypt_materials` method.

    .. versionadded:: 1.3.0

    :param algorithm: Algorithm to provide to master keys for underlying decrypt requests
    :type algorithm: aws_encryption_sdk.identifiers.Algorithm
    :param encrypted_data_keys: Set of encrypted data keys
    :type encrypted_data_keys: set of `aws_encryption_sdk.structures.EncryptedDataKey`
    :param dict encryption_context: Encryption context to provide to master keys for underlying decrypt requests
    """

    algorithm = attr.ib(validator=instance_of(Algorithm))
    # Every member of the iterable must be an EncryptedDataKey.
    encrypted_data_keys = attr.ib(validator=deep_iterable(member_validator=instance_of(EncryptedDataKey)))
    # Keys and values must both be strings (six.string_types covers py2/py3).
    encryption_context = attr.ib(
        validator=deep_mapping(
            key_validator=instance_of(six.string_types), value_validator=instance_of(six.string_types)
        )
    )
Ejemplo n.º 19
0
class ProjectTemplates:
    """Container to hold all templates for a PipeFormer project.

    :param core: Core stack template
    :param inputs: Inputs stack template
    :param iam: IAM stack template
    :param pipeline: CodePipeline stack template
    :param codebuild: Mapping of stage names to corresponding CodeBuild templates
    """

    core: Template = attr.ib(validator=instance_of(Template))
    inputs: Template = attr.ib(validator=instance_of(Template))
    iam: Template = attr.ib(validator=instance_of(Template))
    pipeline: Template = attr.ib(validator=instance_of(Template))
    # The mapping itself must be an OrderedDict (enforced by mapping_validator).
    codebuild: Dict[str, Template] = attr.ib(
        validator=deep_mapping(
            key_validator=instance_of(str),
            value_validator=instance_of(Template),
            mapping_validator=instance_of(OrderedDict),
        )
    )
Ejemplo n.º 20
0
class CryptographicMaterials(object):
    """Cryptographic materials core.

    Instances freeze themselves after ``__init__``: any later attribute
    assignment raises :class:`AttributeError`.

    .. versionadded:: 1.5.0

    :param Algorithm algorithm: Algorithm to use for encrypting message
    :param dict encryption_context: Encryption context tied to `encrypted_data_keys`
    :param RawDataKey data_encryption_key: Plaintext data key to use for encrypting message
    :param keyring_trace: Any KeyRing trace entries
    :type keyring_trace: list of :class:`KeyringTrace`
    """

    algorithm = attr.ib(validator=optional(instance_of(Algorithm)))
    encryption_context = attr.ib(
        validator=optional(
            deep_mapping(key_validator=instance_of(six.string_types), value_validator=instance_of(six.string_types))
        )
    )
    # The converter normalizes DataKey inputs to RawDataKey.
    data_encryption_key = attr.ib(
        default=None, validator=optional(instance_of(RawDataKey)), converter=_data_key_to_raw_data_key
    )
    _keyring_trace = attr.ib(
        default=attr.Factory(list), validator=optional(deep_iterable(member_validator=instance_of(KeyringTrace)))
    )
    # Class-level sentinel read by __setattr__ while __init__ is still
    # running; flipped to an instance attribute in __attrs_post_init__.
    _initialized = False

    def __attrs_post_init__(self):
        """Freeze attributes after initialization."""
        self._initialized = True

    def __setattr__(self, key, value):
        # type: (str, Any) -> None
        """Do not allow attributes to be changed once an instance is initialized."""
        if self._initialized:
            raise AttributeError("can't set attribute")

        self._setattr(key, value)

    def _setattr(self, key, value):
        # type: (str, Any) -> None
        """Special __setattr__ to avoid having to perform multi-level super calls."""
        super(CryptographicMaterials, self).__setattr__(key, value)

    def _validate_data_encryption_key(self, data_encryption_key, keyring_trace, required_flags):
        # type: (Union[DataKey, RawDataKey], KeyringTrace, Iterable[KeyringTraceFlag]) -> None
        """Validate that the provided data encryption key and keyring trace match for each other and the materials.

        .. versionadded:: 1.5.0

        :param RawDataKey data_encryption_key: Data encryption key
        :param KeyringTrace keyring_trace: Keyring trace corresponding to data_encryption_key
        :param required_flags: Iterable of required flags
        :type required_flags: iterable of :class:`KeyringTraceFlag`
        :raises AttributeError: if data encryption key is already set
        :raises InvalidKeyringTraceError: if keyring trace does not match decrypt action
        :raises InvalidKeyringTraceError: if keyring trace does not match data key provider
        :raises InvalidDataKeyError: if data key length does not match algorithm suite
        """
        # A data key may be set at most once per materials instance.
        if self.data_encryption_key is not None:
            raise AttributeError("Data encryption key is already set.")

        for flag in required_flags:
            if flag not in keyring_trace.flags:
                raise InvalidKeyringTraceError("Keyring flags do not match action.")

        if keyring_trace.wrapping_key != data_encryption_key.key_provider:
            raise InvalidKeyringTraceError("Keyring trace does not match data key provider.")

        if len(data_encryption_key.data_key) != self.algorithm.kdf_input_len:
            raise InvalidDataKeyError(
                "Invalid data key length {actual} must be {expected}.".format(
                    actual=len(data_encryption_key.data_key), expected=self.algorithm.kdf_input_len
                )
            )

    def _with_data_encryption_key(self, data_encryption_key, keyring_trace, required_flags):
        # type: (Union[DataKey, RawDataKey], KeyringTrace, Iterable[KeyringTraceFlag]) -> CryptographicMaterials
        """Get new cryptographic materials that include this data encryption key.

        .. versionadded:: 1.5.0

        :param RawDataKey data_encryption_key: Data encryption key
        :param KeyringTrace keyring_trace: Trace of actions that a keyring performed
          while getting this data encryption key
        :param required_flags: Iterable of required flags
        :type required_flags: iterable of :class:`KeyringTraceFlag`
        :raises AttributeError: if data encryption key is already set
        :raises InvalidKeyringTraceError: if keyring trace does not match required actions
        :raises InvalidKeyringTraceError: if keyring trace does not match data key provider
        :raises InvalidDataKeyError: if data key length does not match algorithm suite
        """
        self._validate_data_encryption_key(
            data_encryption_key=data_encryption_key, keyring_trace=keyring_trace, required_flags=required_flags
        )

        # Shallow copy: the copy is updated via _setattr because the frozen
        # __setattr__ would otherwise reject the assignment.
        new_materials = copy.copy(self)

        data_key = _data_key_to_raw_data_key(data_key=data_encryption_key)
        new_materials._setattr(  # simplify access to copies pylint: disable=protected-access
            "data_encryption_key", data_key
        )
        # NOTE(review): copy.copy shares the _keyring_trace list, so this
        # append is visible on the original instance as well — confirm that
        # is intended.
        new_materials._keyring_trace.append(keyring_trace)  # simplify access to copies pylint: disable=protected-access

        return new_materials

    @property
    def keyring_trace(self):
        # type: () -> Tuple[KeyringTrace]
        """Return a read-only version of the keyring trace.

        :rtype: tuple
        """
        return tuple(self._keyring_trace)
Ejemplo n.º 21
0
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Internal data structures."""
from collections import OrderedDict
from typing import Dict, Iterable, Optional, Set

import attr
import oyaml as yaml
from attr.validators import deep_iterable, deep_mapping, instance_of, optional
from troposphere import Ref, Sub, Template, cloudformation, secretsmanager, ssm

from .util import reference_name, resource_name

__all__ = ("Config", "PipelineStage", "PipelineAction", "Input", "ProjectTemplates", "WaitConditionStack", "Pipeline")
# Reusable validator for Dict[str, str] attributes (string keys to string values).
_STRING_STRING_MAP = deep_mapping(key_validator=instance_of(str), value_validator=instance_of(str))


def _resolve_parameter(name: Ref, version: str) -> Sub:
    """Build a CloudFormation dynamic reference string structure that resolves a SSM Parameter.

    :param name: Parameter name
    :param version: Parameter version
    :return: Dynamic reference
    """
    # "${name}" is left literal for Sub to substitute; the surrounding double
    # braces form CloudFormation's dynamic-reference syntax.
    reference = "{{resolve:ssm:${name}:" + version + "}}"
    return Sub(reference, {"name": name})


def _resolve_secret(arn: Ref) -> Sub:
    """Build a CloudFormation dynamic reference string structure that resolves a Secrets Manager secret.
Ejemplo n.º 22
0
class LearnerSemantics:
    """
    Represent's the learner's semantic (rather than perceptual) understanding of a situation.
    The learner is assumed to view the situation as a collection of *objects* which possess
    *attributes*, have *relations* to one another, and serve as the arguments of *actions*.
    """

    objects: ImmutableSet[ObjectSemanticNode] = attrib(
        converter=_to_immutableset)
    attributes: ImmutableSet[AttributeSemanticNode] = attrib(
        converter=_to_immutableset)
    relations: ImmutableSet[RelationSemanticNode] = attrib(
        converter=_to_immutableset)
    actions: ImmutableSet[ActionSemanticNode] = attrib(
        converter=_to_immutableset)

    # Optional resolution of functional object concepts to concrete object
    # concepts; defaults to empty.
    functional_concept_to_object_concept: ImmutableDict[
        FunctionalObjectConcept, ObjectConcept] = attrib(
            converter=_to_immutabledict,
            validator=deep_mapping(instance_of(FunctionalObjectConcept),
                                   instance_of(ObjectConcept)),
            default=immutabledict(),
        )

    # Derived indexes, computed from the fields above via the @...default
    # initializers below (init=False: not constructor arguments).
    objects_to_attributes: ImmutableSetMultiDict[
        ObjectSemanticNode, AttributeSemanticNode] = attrib(init=False)
    objects_to_relation_in_slot1: ImmutableSetMultiDict[
        ObjectSemanticNode, RelationSemanticNode] = attrib(init=False)
    objects_to_actions: ImmutableSetMultiDict[ObjectSemanticNode,
                                              ActionSemanticNode] = attrib(
                                                  init=False)

    @staticmethod
    def from_nodes(
        semantic_nodes: Iterable[SemanticNode],
        *,
        concept_map: ImmutableDict[FunctionalObjectConcept,
                                   ObjectConcept] = immutabledict(),
    ) -> "LearnerSemantics":
        """Partition *semantic_nodes* by node type into a LearnerSemantics."""
        return LearnerSemantics(
            objects=[
                node for node in semantic_nodes
                if isinstance(node, ObjectSemanticNode)
            ],
            attributes=[
                node for node in semantic_nodes
                if isinstance(node, AttributeSemanticNode)
            ],
            relations=[
                node for node in semantic_nodes
                if isinstance(node, RelationSemanticNode)
            ],
            actions=[
                node for node in semantic_nodes
                if isinstance(node, ActionSemanticNode)
            ],
            functional_concept_to_object_concept=concept_map,
        )

    @objects_to_attributes.default
    def _init_objects_to_attributes(
        self
    ) -> ImmutableSetMultiDict[ObjectSemanticNode, AttributeSemanticNode]:
        # one(...) assumes each attribute node has exactly one slot filling.
        return immutablesetmultidict(
            (one(attribute.slot_fillings.values()), attribute)
            for attribute in self.attributes)

    @objects_to_relation_in_slot1.default
    def _init_objects_to_relations(
        self
    ) -> ImmutableSetMultiDict[ObjectSemanticNode, RelationSemanticNode]:
        # NOTE(review): despite the "in_slot1" name, this indexes relations
        # by every slot filler, not just slot 1 — confirm intent.
        return immutablesetmultidict(
            flatten([(slot_filler, relation)
                     for slot_filler in relation.slot_fillings.values()]
                    for relation in self.relations))

    @objects_to_actions.default
    def _init_objects_to_actions(
        self
    ) -> ImmutableSetMultiDict[ObjectSemanticNode, ActionSemanticNode]:
        return immutablesetmultidict(
            flatten([(slot_filler, action)
                     for slot_filler in action.slot_fillings.values()]
                    for action in self.actions))
Ejemplo n.º 23
0
class InputResolver:
    """Proxy around another structure that rewrites input tags as values are read.

    Any string read through this wrapper that contains an input tag is expanded
    into a CloudFormation dynamic reference which resolves the input value.

    Every input referenced this way is recorded in ``required_inputs``,
    which can later be consulted to determine the parameters
    a given CloudFormation template needs.

    :param wrapped: Wrapped structure
    :param inputs: Map of input names to :class:`Input` structures
    :param required_inputs: Known required input (optional)
    """

    _wrapped = attr.ib()
    _inputs = attr.ib(validator=deep_mapping(key_validator=instance_of(str), value_validator=instance_of(Input)))
    required_inputs = attr.ib(default=attr.Factory(set))

    @_wrapped.validator
    def _validate_wrapped(self, attribute, value):  # pylint: disable=unused-argument,no-self-use
        """Reject objects that cannot safely be wrapped.
        Used by attrs as the validator for the ``_wrapped`` attribute.
        """
        if isinstance(value, InputResolver):
            raise TypeError(f"{InputResolver!r} cannot wrap itself.")

        # The proxy exposes these names itself, so the wrapped object must not.
        for reserved in ("required_inputs",):
            if hasattr(value, reserved):
                raise TypeError(f'Wrapped object must not have "{reserved}" attribute.')

    def __attrs_post_init__(self):
        """Expose mapping-style helpers only when the wrapped object supports them."""
        for proxied in ("get", "keys", "values", "items"):
            if hasattr(self._wrapped, proxied):
                setattr(self, proxied, getattr(self, f"_{proxied}"))

    def __expand_values(self, value: str) -> Iterable[str]:
        """Split *value* into (prefix, dynamic reference, suffix) around its input tag."""
        prefix, name, suffix = _value_to_triplet(value, *_INPUT_TAG)

        # Look up the reference before recording the name so an unknown input
        # raises without being marked as required.
        reference = self._inputs[name].dynamic_reference()

        self.required_inputs.add(name)
        return prefix, reference, suffix

    def __convert_value(self, value) -> Union[_PrimitiveTypes, "InputResolver", str, Join]:
        """Translate a raw value from the wrapped object into its resolved form."""
        if isinstance(value, _PRIMITIVE_TYPES):
            return value

        if isinstance(value, str):
            if not _tag_in_string(value, *_INPUT_TAG):
                return value
            return Join("", self.__expand_values(value))

        # Non-primitive, non-string values get wrapped so nested reads resolve too.
        return InputResolver(wrapped=value, inputs=self._inputs, required_inputs=self.required_inputs)

    @staticmethod
    def _unwrap(other):
        """Return the bare wrapped object when *other* is itself an InputResolver."""
        if isinstance(other, InputResolver):
            return other._wrapped  # pylint: disable=protected-access
        return other

    def __len__(self):
        """Passthrough length from wrapped."""
        return len(self._wrapped)

    def __eq__(self, other) -> bool:
        """Passthrough eq from wrapped."""
        return self._wrapped.__eq__(self._unwrap(other))

    def __lt__(self, other) -> bool:
        """Passthrough lt from wrapped."""
        return self._wrapped.__lt__(self._unwrap(other))

    def __gt__(self, other) -> bool:
        """Passthrough gt from wrapped."""
        return self._wrapped.__gt__(self._unwrap(other))

    def __le__(self, other) -> bool:
        """Passthrough le from wrapped."""
        return self._wrapped.__le__(self._unwrap(other))

    def __ge__(self, other) -> bool:
        """Passthrough ge from wrapped."""
        return self._wrapped.__ge__(self._unwrap(other))

    def __str__(self) -> str:
        """Passthrough str from wrapped."""
        # TODO: Do we need to convert this?
        return self._wrapped.__str__()

    def __getattr__(self, name):
        """Read an attribute from wrapped, converting the result."""
        return self.__convert_value(getattr(self._wrapped, name))

    def __call__(self, *args, **kwargs):
        """Call wrapped, converting the result."""
        return self.__convert_value(self._wrapped(*args, **kwargs))

    def __getitem__(self, key):
        """Index into wrapped, converting the result."""
        return self.__convert_value(self._wrapped[key])

    def __iter__(self) -> Iterable["InputResolver"]:
        """Iterate through wrapped, converting each element."""
        for element in self._wrapped:
            yield self.__convert_value(element)

    def __reversed__(self) -> Iterable["InputResolver"]:
        """Reverse wrapped, converting the result."""
        return self.__convert_value(reversed(self._wrapped))

    def __next__(self) -> "InputResolver":
        """Advance wrapped, converting the result."""
        return self.__convert_value(self._wrapped.__next__())

    def _get(self, key, default=None) -> "InputResolver":
        """Call wrapped.get, converting the result."""
        return self.__convert_value(self._wrapped.get(key, default))

    def _items(self) -> Iterable[Iterable["InputResolver"]]:
        """Call wrapped.items, converting the resulting keys and values."""
        convert = self.__convert_value
        for key, value in self._wrapped.items():
            yield (convert(key), convert(value))

    def _keys(self) -> Iterable["InputResolver"]:
        """Call wrapped.keys, converting the resulting keys."""
        yield from (self.__convert_value(key) for key in self._wrapped.keys())

    def _values(self) -> Iterable["InputResolver"]:
        """Call wrapped.values, converting the resulting values."""
        yield from (self.__convert_value(value) for value in self._wrapped.values())
Ejemplo n.º 24
0
    """
    Create a new attr attribute with validator for an atribute that is a dictionary with keys as str (to represent
    the name) and the content of an instance of type_
    """
    metadata = {"type": "dict_of_instance", "class_": type_}
    return attr.ib(
        default=attr.Factory(dict),
        validator=dict_of(type_),
        type=Dict[str, type_],
        metadata=metadata,
    )


# Validator for Dict[str, Array]-shaped attributes.
dict_of_array = deep_mapping(
    mapping_validator=instance_of(dict),
    key_validator=instance_of(str),
    value_validator=instance_of(Array),
)
# Validator for Dict[str, Scalar]-shaped attributes.
dict_with_scalar = deep_mapping(
    mapping_validator=instance_of(dict),
    key_validator=instance_of(str),
    value_validator=instance_of(Scalar),
)
# Validator for a list (or range) whose members are all numbers.
list_of_numbers = deep_iterable(
    member_validator=instance_of(Number),
    iterable_validator=instance_of((list, range)),
)
# Validator for Dict[str, <list of numbers>]-shaped attributes.
dict_with_a_list_of_numbers = deep_mapping(
    mapping_validator=instance_of(dict),
    key_validator=instance_of(str),
    value_validator=list_of_numbers,
)
Ejemplo n.º 25
0
class ObjectRecognizer:
    """
    The ObjectRecognizer finds object matches in the scene pattern and adds a `ObjectSemanticNodePerceptionPredicate`
    which can be used to learn additional semantics which relate objects to other objects

    If applied to a dynamic situation, this will only recognize objects
    which are present in both the BEFORE and AFTER frames.
    """

    # Because static patterns must be applied to static perceptions
    # and dynamic patterns to dynamic situations,
    # we need to store our patterns both ways.
    _concepts_to_static_patterns: ImmutableDict[
        ObjectConcept, PerceptionGraphPattern] = attrib(
            validator=deep_mapping(instance_of(ObjectConcept),
                                   instance_of(PerceptionGraphPattern)),
            converter=_to_immutabledict,
        )
    # Human-readable name for each concept; used when aligning matches to language.
    _concepts_to_names: ImmutableDict[ObjectConcept, str] = attrib(
        validator=deep_mapping(instance_of(ObjectConcept), instance_of(str)),
        converter=_to_immutabledict,
    )

    # We derive these from the static patterns.
    _concepts_to_dynamic_patterns: ImmutableDict[
        ObjectConcept, PerceptionGraphPattern] = attrib(init=False)
    determiners: ImmutableSet[str] = attrib(converter=_to_immutableset,
                                            validator=deep_iterable(
                                                instance_of(str)))
    """
    This is a hack to handle determiners.
    See https://github.com/isi-vista/adam/issues/498
    """
    _concept_to_num_subobjects: ImmutableDict[Concept,
                                              int] = attrib(init=False)
    """
    Used for a performance optimization in match_objects.
    """
    _language_mode: LanguageMode = attrib(validator=instance_of(LanguageMode),
                                          kw_only=True)

    def __attrs_post_init__(self) -> None:
        """Validate that every configured determiner is lowercase."""
        non_lowercase_determiners = [
            determiner for determiner in self.determiners
            if determiner.lower() != determiner
        ]
        if non_lowercase_determiners:
            raise RuntimeError(
                f"All determiners must be specified in lowercase, but got "
                f"{non_lowercase_determiners}")

    @staticmethod
    def for_ontology_types(
        ontology_types: Iterable[OntologyNode],
        determiners: Iterable[str],
        ontology: Ontology,
        language_mode: LanguageMode,
        *,
        perception_generator:
        HighLevelSemanticsSituationToDevelopmentalPrimitivePerceptionGenerator,
    ) -> "ObjectRecognizer":
        """Build an `ObjectRecognizer` whose patterns are derived from *ontology_types*.

        One `ObjectConcept` (named by the node's handle) and one static
        perception-graph pattern are generated per ontology node.

        :param ontology_types: Ontology nodes to generate patterns for
        :param determiners: Lowercase determiners to gobble during alignment
        :param ontology: Ontology the types come from
        :param language_mode: Language the recognizer describes objects in
        :param perception_generator: Generator used to derive perception patterns
        """
        ontology_types_to_concepts = {
            obj_type: ObjectConcept(obj_type.handle)
            for obj_type in ontology_types
        }

        return ObjectRecognizer(
            concepts_to_static_patterns=_sort_mapping_by_pattern_complexity(
                immutabledict((
                    concept,
                    PerceptionGraphPattern.from_ontology_node(
                        obj_type,
                        ontology,
                        perception_generator=perception_generator),
                ) for (obj_type,
                       concept) in ontology_types_to_concepts.items())),
            determiners=determiners,
            concepts_to_names={
                concept: obj_type.handle
                for obj_type, concept in ontology_types_to_concepts.items()
            },
            language_mode=language_mode,
        )

    def match_objects_old(
        self, perception_graph: PerceptionGraph
    ) -> PerceptionGraphFromObjectRecognizer:
        """Old-style API wrapper: delegate to `match_objects` and repackage the result."""
        new_style_input = PerceptionSemanticAlignment(
            perception_graph=perception_graph, semantic_nodes=[])
        new_style_output = self.match_objects(new_style_input)
        return PerceptionGraphFromObjectRecognizer(
            perception_graph=new_style_output[0].perception_graph,
            description_to_matched_object_node=new_style_output[1],
        )

    def match_objects(
        self,
        perception_semantic_alignment: PerceptionSemanticAlignment,
        *,
        post_process: Callable[[PerceptionGraph, AbstractSet[SemanticNode]],
                               Tuple[PerceptionGraph,
                                     AbstractSet[SemanticNode]],
                               ] = default_post_process_enrichment,
    ) -> Tuple[PerceptionSemanticAlignment, Mapping[Tuple[str, ...],
                                                    ObjectSemanticNode]]:
        r"""
        Recognize known objects in a `PerceptionGraph`.

        The matched portion of the graph will be replaced with an `ObjectSemanticNode`\ s
        which will inherit all relationships of any nodes internal to the matched portion
        with any external nodes.

        This is useful as a pre-processing step
        before prepositional and verbal learning experiments.
        """

        # Module-level counters accumulate match timing across calls for logging below.
        # pylint: disable=global-statement,invalid-name
        global cumulative_millis_in_successful_matches_ms
        global cumulative_millis_in_failed_matches_ms

        object_nodes: List[Tuple[Tuple[str, ...], ObjectSemanticNode]] = []
        perception_graph = perception_semantic_alignment.perception_graph
        is_dynamic = perception_semantic_alignment.perception_graph.dynamic

        # Static patterns only match static perceptions and vice versa.
        if is_dynamic:
            concepts_to_patterns = self._concepts_to_dynamic_patterns
        else:
            concepts_to_patterns = self._concepts_to_static_patterns

        # We special case handling the ground perception
        # Because we don't want to remove it from the graph, we just want to use it's
        # Object node as a recognized object. The situation "a box on the ground"
        # Prompted the need to recognize the ground
        graph_to_return = perception_graph
        for node in graph_to_return._graph.nodes:  # pylint:disable=protected-access
            if node == GROUND_PERCEPTION:
                matched_object_node = ObjectSemanticNode(GROUND_OBJECT_CONCEPT)
                if LanguageMode.ENGLISH == self._language_mode:
                    object_nodes.append(
                        ((f"{GROUND_OBJECT_CONCEPT.debug_string}", ),
                         matched_object_node))
                elif LanguageMode.CHINESE == self._language_mode:
                    object_nodes.append((("di4 myan4", ), matched_object_node))
                else:
                    raise RuntimeError("Invalid language_generator")
                # We construct a fake match which is only the ground perception node
                subgraph_of_root = subgraph(perception_graph.copy_as_digraph(),
                                            [node])
                pattern_match = PerceptionGraphPatternMatch(
                    matched_pattern=PerceptionGraphPattern(
                        graph=subgraph_of_root,
                        dynamic=perception_graph.dynamic),
                    graph_matched_against=perception_graph,
                    matched_sub_graph=PerceptionGraph(
                        graph=subgraph_of_root,
                        dynamic=perception_graph.dynamic),
                    pattern_node_to_matched_graph_node=immutabledict(),
                )
                graph_to_return = replace_match_with_object_graph_node(
                    matched_object_node, graph_to_return, pattern_match)

        # NOTE(review): candidates are extracted from the original perception_graph,
        # not from graph_to_return (which may already have the ground replaced) —
        # confirm this is intentional.
        candidate_object_subgraphs = extract_candidate_objects(
            perception_graph)

        for candidate_object_graph in candidate_object_subgraphs:
            num_object_nodes = candidate_object_graph.count_nodes_matching(
                lambda node: isinstance(node, ObjectPerception))

            for (concept, pattern) in concepts_to_patterns.items():
                # As an optimization, we count how many sub-object nodes
                # are in the graph and the pattern.
                # If they aren't the same, the match is impossible
                # and we can bail out early.
                if num_object_nodes != self._concept_to_num_subobjects[concept]:
                    continue

                with Timer(factor=1000) as t:
                    matcher = pattern.matcher(candidate_object_graph,
                                              match_mode=MatchMode.OBJECT)
                    pattern_match = first(
                        matcher.matches(use_lookahead_pruning=True), None)
                if pattern_match:
                    cumulative_millis_in_successful_matches_ms += t.elapsed
                    matched_object_node = ObjectSemanticNode(concept)

                    # We wrap the concept in a tuple because it could in theory be multiple
                    # tokens,
                    # even though currently it never is.
                    if self._language_mode == LanguageMode.ENGLISH:
                        object_nodes.append(
                            ((concept.debug_string, ), matched_object_node))
                    elif self._language_mode == LanguageMode.CHINESE:
                        if concept.debug_string == "me":
                            object_nodes.append(
                                (("wo3", ), matched_object_node))
                        elif concept.debug_string == "you":
                            object_nodes.append(
                                (("ni3", ), matched_object_node))
                        # NOTE(review): there is no continue/else here, so after
                        # appending "wo3"/"ni3" the lexicon scan below still runs
                        # and may append a second entry for the same node —
                        # confirm this fall-through is intended.
                        mappings = (
                            GAILA_PHASE_1_CHINESE_LEXICON.
                            _ontology_node_to_word  # pylint:disable=protected-access
                        )
                        for k, v in mappings.items():
                            if k.handle == concept.debug_string:
                                debug_string = str(v.base_form)
                                object_nodes.append(
                                    ((debug_string, ), matched_object_node))
                    graph_to_return = replace_match_with_object_graph_node(
                        matched_object_node, graph_to_return, pattern_match)
                    # We match each candidate objects against only one object type.
                    # See https://github.com/isi-vista/adam/issues/627
                    break
                else:
                    cumulative_millis_in_failed_matches_ms += t.elapsed
        if object_nodes:
            logging.info(
                "Object recognizer recognized: %s",
                [concept for (concept, _) in object_nodes],
            )
        logging.info(
            "object matching: ms in success: %s, ms in failed: %s",
            cumulative_millis_in_successful_matches_ms,
            cumulative_millis_in_failed_matches_ms,
        )
        semantic_object_nodes = immutableset(node
                                             for (_, node) in object_nodes)

        post_process_graph, post_process_nodes = post_process(
            graph_to_return, semantic_object_nodes)

        return (
            perception_semantic_alignment.
            copy_with_updated_graph_and_added_nodes(
                new_graph=post_process_graph, new_nodes=post_process_nodes),
            immutabledict(object_nodes),
        )

    def match_objects_with_language_old(
        self, language_aligned_perception: LanguageAlignedPerception
    ) -> LanguageAlignedPerception:
        """Old-style API wrapper: delegate to `match_objects_with_language` and repackage.

        :raises RuntimeError: If the input already has node-to-language alignments.
        """
        if language_aligned_perception.node_to_language_span:
            raise RuntimeError(
                "Don't know how to handle a non-empty node-to-language-span")
        new_style_input = LanguagePerceptionSemanticAlignment(
            language_concept_alignment=LanguageConceptAlignment(
                language_aligned_perception.language,
                node_to_language_span=[]),
            perception_semantic_alignment=PerceptionSemanticAlignment(
                perception_graph=language_aligned_perception.perception_graph,
                semantic_nodes=[],
            ),
        )
        new_style_output = self.match_objects_with_language(new_style_input)
        return LanguageAlignedPerception(
            language=new_style_output.language_concept_alignment.language,
            perception_graph=new_style_output.perception_semantic_alignment.
            perception_graph,
            node_to_language_span=new_style_output.language_concept_alignment.
            node_to_language_span,
        )

    def match_objects_with_language(
        self,
        language_perception_semantic_alignment:
        LanguagePerceptionSemanticAlignment,
        *,
        post_process: Callable[[PerceptionGraph, AbstractSet[SemanticNode]],
                               Tuple[PerceptionGraph,
                                     AbstractSet[SemanticNode]],
                               ] = default_post_process_enrichment,
    ) -> LanguagePerceptionSemanticAlignment:
        """
        Recognize known objects in a `LanguagePerceptionSemanticAlignment`.

        For each node matched, this will identify the relevant portion of the linguistic input
        and record the correspondence.

        The matched portion of the graph will be replaced with an `ObjectSemanticNode`
        which will inherit all relationships of any nodes internal to the matched portion
        with any external nodes.

        This is useful as a pre-processing step
        before prepositional and verbal learning experiments.
        """
        if (language_perception_semantic_alignment.
                perception_semantic_alignment.semantic_nodes):
            raise RuntimeError(
                "We assume ObjectRecognizer is run first, with no previous "
                "alignments")

        (
            post_match_perception_semantic_alignment,
            tokens_to_object_nodes,
        ) = self.match_objects(
            language_perception_semantic_alignment.
            perception_semantic_alignment,
            post_process=post_process,
        )
        return LanguagePerceptionSemanticAlignment(
            language_concept_alignment=language_perception_semantic_alignment.
            language_concept_alignment.copy_with_added_token_alignments(
                self._align_objects_to_tokens(
                    tokens_to_object_nodes,
                    language_perception_semantic_alignment.
                    language_concept_alignment.language,
                )),
            perception_semantic_alignment=
            post_match_perception_semantic_alignment,
        )

    def _align_objects_to_tokens(
        self,
        description_to_object_node: Mapping[Tuple[str, ...],
                                            ObjectSemanticNode],
        language: LinguisticDescription,
    ) -> Mapping[ObjectSemanticNode, Span]:
        """Map each matched object node to the span of tokens describing it.

        Objects whose description does not occur in *language* are skipped,
        since a scene may contain things the sentence never mentions.

        :raises RuntimeError: On multi-token descriptions or when two objects
            would claim the same token.
        """
        result: List[Tuple[ObjectSemanticNode, Span]] = []

        # We want to ban the same token index from being aligned twice.
        matched_token_indices: Set[int] = set()

        for (description_tuple,
             object_node) in description_to_object_node.items():
            if len(description_tuple) != 1:
                raise RuntimeError(
                    f"Multi-token descriptions are not yet supported:"
                    f"{description_tuple}")
            description = description_tuple[0]
            try:
                end_index_inclusive = language.index(description)
            except ValueError:
                # A scene might contain things which are not referred to by the associated language.
                continue

            start_index = end_index_inclusive
            # This is a somewhat language-dependent hack to gobble up preceding determiners.
            # See https://github.com/isi-vista/adam/issues/498 .
            if end_index_inclusive > 0:
                possible_determiner_index = end_index_inclusive - 1
                if language[possible_determiner_index].lower(
                ) in self.determiners:
                    start_index = possible_determiner_index

            # We record what tokens were covered so we can block the same tokens being used twice.
            for included_token_index in range(start_index,
                                              end_index_inclusive + 1):
                if included_token_index in matched_token_indices:
                    raise RuntimeError(
                        "We do not currently support the same object "
                        "being mentioned twice in a sentence.")
                matched_token_indices.add(included_token_index)

            result.append((
                object_node,
                language.span(start_index,
                              end_index_exclusive=end_index_inclusive + 1),
            ))
        return immutabledict(result)

    @_concepts_to_dynamic_patterns.default
    def _init_concepts_to_dynamic_patterns(
            self) -> ImmutableDict[ObjectConcept, PerceptionGraphPattern]:
        """Derive dynamic patterns by scoping each static pattern over the entire scene."""
        return immutabledict(
            (concept, static_pattern.copy_with_temporal_scopes(ENTIRE_SCENE))
            for (concept,
                 static_pattern) in self._concepts_to_static_patterns.items())

    @_concept_to_num_subobjects.default
    def _init_patterns_to_num_subobjects(
            self) -> ImmutableDict[ObjectConcept, int]:
        """Precompute each pattern's sub-object count for the early-bail optimization in match_objects."""
        return immutabledict((
            concept,
            pattern.count_nodes_matching(
                lambda node: isinstance(node, AnyObjectPerception)),
        ) for (concept, pattern) in self._concepts_to_static_patterns.items())
Ejemplo n.º 26
0
class SituationObject(HasAxes):
    r"""
    An object present in some situation.

    Every object must refer to an `OntologyNode` linking it to a type in an ontology.

    Unlike most of our classes, `SituationObject` has *id*-based hashing and equality,
    because two objects with identical properties are nonetheless distinct.

    `SituationObject`\ should not be directly instantiated.
    Instead use `instantiate_ontology_node`.
    """

    ontology_node: OntologyNode = attrib(validator=instance_of(OntologyNode))
    """
    The `OntologyNode` specifying the type of thing this object is.
    """
    # note to readers: the two axis-related fields have to be placed first to make PyCharm happy,
    # but you should read the other fields first.
    axes: Axes = attrib(validator=instance_of(Axes), kw_only=True)
    schema_axis_to_object_axis: Mapping[GeonAxis, GeonAxis] = attrib(
        validator=deep_mapping(instance_of(GeonAxis), instance_of(GeonAxis)),
        kw_only=True,
        converter=_to_immutabledict,
    )
    """
    Maps the abstract, generic axes of an `ObjectStructuralSchema`
    (i.e. the axes of "tire"s in general)
    to the concrete instantiations of those axes stored in this object's *axes* field.
    We track this information to keep the object axes in sync with the `Geon` axes
    during perceptual generation.

    Note that this mapping may be empty if this situation object was not derived from
    an `ObjectStructuralSchema`.

    Rather than setting this field by hand, we recommend using the static factory method
    `from_structural_schema`.
    """
    properties: ImmutableSet[OntologyNode] = attrib(converter=_to_immutableset,
                                                    default=immutableset(),
                                                    kw_only=True)
    r"""
    The `OntologyNode`\ s representing the properties this object has.
    """
    debug_handle: str = attrib(validator=instance_of(str), kw_only=True)

    def __attrs_post_init__(self) -> None:
        """Sanity-check property types and axis alignment once attrs has initialized."""
        # disabled warning below is due to a PyCharm bug
        # noinspection PyTypeChecker
        for prop in self.properties:
            if not isinstance(prop, OntologyNode):
                raise ValueError(
                    f"Situation object property {prop} is not an "
                    f"OntologyNode")
        for instantiated_axis in self.schema_axis_to_object_axis.values():
            check_arg(instantiated_axis in self.axes.all_axes)
        # Every object should either have axes mapped to the axes of a schema object,
        # or should have WORLD_AXES, which is what we use by default for things
        # like substances which have no particular shape.
        check_arg(
            self.schema_axis_to_object_axis or self.axes == WORLD_AXES,
            "Axes must be aligned to a scheme or else be WORLD_AXES",
        )

    @debug_handle.default
    def _default_debug_handle(self) -> str:
        """Fall back to the ontology node's handle as the debug name."""
        return str(self.ontology_node.handle)

    def __repr__(self) -> str:
        if self.properties:
            joined_props = ", ".join(repr(prop) for prop in self.properties)
            properties_suffix = f"[{joined_props}]"
        else:
            properties_suffix = ""

        # Show the type explicitly only when the debug handle does not already imply it.
        show_type = self.ontology_node and not self.debug_handle.startswith(
            self.ontology_node.handle)
        type_marker = f"[{self.ontology_node.handle}]" if show_type else ""

        return f"{self.debug_handle}{type_marker}{properties_suffix}"

    @staticmethod
    def instantiate_ontology_node(
        ontology_node: OntologyNode,
        *,
        properties: Iterable[OntologyNode] = immutableset(),
        debug_handle: Optional[str] = None,
        ontology: Ontology,
    ) -> "SituationObject":
        """
        Make a `SituationObject` from the object type *ontology_node*
        with properties *properties*.
        The *properties* and ontology node must all come from *ontology*.
        """
        axis_mapping: Mapping[GeonAxis, GeonAxis]
        if ontology.has_property(ontology_node, IS_SUBSTANCE):
            # It is not clear what concrete axes a substance should have,
            # so we fall back to the world axes for now.
            axis_mapping = immutabledict()
            concrete_axes = WORLD_AXES
        else:
            structural_schemata = ontology.structural_schemata(ontology_node)
            if not structural_schemata:
                raise RuntimeError(
                    f"No structural schema found for {ontology_node}")
            if len(structural_schemata) > 1:
                raise RuntimeError(
                    f"Multiple structural schemata available for {ontology_node}, "
                    f"please construct the SituationObject manually: "
                    f"{structural_schemata}")
            abstract_axes = only(structural_schemata).axes
            # The copy is needed or e.g. all tires of a truck
            # would share the same axis objects.
            concrete_axes = abstract_axes.copy()
            axis_mapping = immutabledict(
                zip(abstract_axes.all_axes, concrete_axes.all_axes))

        return SituationObject(
            ontology_node=ontology_node,
            properties=properties,
            debug_handle=debug_handle or ontology_node.handle,
            axes=concrete_axes,
            schema_axis_to_object_axis=axis_mapping,
        )
Ejemplo n.º 27
0
class PerceptionGraphTemplate:
    """A perception-graph pattern together with its template slot bindings.

    Couples a ``PerceptionGraphPattern`` with a mapping from
    ``SyntaxSemanticsVariable`` slots to the wildcard pattern nodes standing
    in for the objects which fill those slots, so that a successful pattern
    match can be translated back into template slot assignments.

    NOTE(review): this class relies on ``attrib``/validator machinery, so it
    is presumably decorated with ``@attrs`` outside this excerpt — confirm at
    the definition site.
    """

    # The underlying pattern to be matched against perception graphs.
    graph_pattern: PerceptionGraphPattern = attrib(
        validator=instance_of(PerceptionGraphPattern), kw_only=True
    )
    # Maps each template slot variable to the pattern node representing the
    # object which fills that slot. May be empty (default).
    template_variable_to_pattern_node: ImmutableDict[
        SyntaxSemanticsVariable, ObjectSemanticNodePerceptionPredicate
    ] = attrib(
        converter=_to_immutabledict,
        kw_only=True,
        validator=deep_mapping(
            instance_of(SyntaxSemanticsVariable),
            instance_of(ObjectSemanticNodePerceptionPredicate),
        ),
        default=immutabledict(),
    )
    # Inverse of template_variable_to_pattern_node; derived automatically by
    # _init_pattern_node_to_template_variable below (init=False).
    pattern_node_to_template_variable: ImmutableDict[
        ObjectSemanticNodePerceptionPredicate, SyntaxSemanticsVariable
    ] = attrib(init=False)

    @staticmethod
    def from_graph(
        perception_graph: PerceptionGraph,
        template_variable_to_matched_object_node: Mapping[
            SyntaxSemanticsVariable, ObjectSemanticNode
        ],
    ) -> "PerceptionGraphTemplate":
        """Build a template from *perception_graph*.

        Each template variable is bound to the pattern node derived from its
        matched object node; object nodes which did not survive the
        graph-to-pattern conversion are silently skipped.
        """
        # It is possible the perception graph has additional recognized objects
        # which are not aligned to surface template slots.
        # We assume these are not arguments of the verb and remove them from the perception
        # before creating a pattern.
        pattern_from_graph = PerceptionGraphPattern.from_graph(perception_graph)
        pattern_graph = pattern_from_graph.perception_graph_pattern
        matched_object_to_matched_predicate = (
            pattern_from_graph.perception_graph_node_to_pattern_node
        )

        template_variable_to_pattern_node: List[Any] = []

        # Keep only the (variable, pattern node) pairs whose object node has
        # a corresponding predicate in the converted pattern.
        for (
            template_variable,
            object_node,
        ) in template_variable_to_matched_object_node.items():
            if object_node in matched_object_to_matched_predicate:
                template_variable_to_pattern_node.append(
                    (template_variable, matched_object_to_matched_predicate[object_node])
                )

        return PerceptionGraphTemplate(
            graph_pattern=pattern_graph,
            template_variable_to_pattern_node=template_variable_to_pattern_node,
        )

    def __attrs_post_init__(self) -> None:
        """Validate that every bound pattern node actually occurs in the pattern graph."""
        object_predicate_nodes = set(self.template_variable_to_pattern_node.values())

        for object_node in object_predicate_nodes:
            if (
                object_node
                not in self.graph_pattern._graph.nodes  # pylint:disable=protected-access
            ):
                raise RuntimeError(
                    f"Expected mapping which contained graph nodes"
                    f" but got {object_node} with id {id(object_node)}"
                    f" which doesn't exist in {self.graph_pattern}"
                )

    def intersection(
        self,
        pattern: "PerceptionGraphTemplate",
        *,
        graph_logger: Optional[GraphLogger] = None,
        ontology: Ontology,
        debug_callback: Optional[Callable[[Any, Any], None]] = None,
        allowed_matches: ImmutableSetMultiDict[
            NodePredicate, NodePredicate
        ] = immutablesetmultidict(),
        match_mode: MatchMode,
        trim_after_match: Optional[
            Callable[[PerceptionGraphPattern], PerceptionGraphPattern]
        ] = None,
    ) -> Optional["PerceptionGraphTemplate"]:
        r"""
        Gets the `PerceptionGraphTemplate` which contains all aspects of a pattern
        which are both in this template and *other_template*.

        If this intersection is an empty graph or would not contain all `SyntaxSemanticsVariable`\ s,
        this returns None.
        """
        if self.graph_pattern.dynamic != pattern.graph_pattern.dynamic:
            raise RuntimeError("Can only intersection patterns of the same dynamic-ness")

        # Both operands must be single weakly-connected components before
        # intersecting; a fragmented pattern indicates an upstream error.
        num_self_weakly_connected = number_weakly_connected_components(
            self.graph_pattern._graph  # pylint:disable=protected-access
        )
        if num_self_weakly_connected > 1:
            raise_graph_exception(
                f"Graph pattern contains multiple ( {num_self_weakly_connected} ) "
                f"weakly connected components heading into intersection. ",
                self.graph_pattern,
            )

        num_pattern_weakly_connected = number_weakly_connected_components(
            pattern.graph_pattern._graph  # pylint:disable=protected-access
        )
        if num_pattern_weakly_connected > 1:
            raise_graph_exception(
                f"Graph pattern contains multiple ( {num_pattern_weakly_connected} ) "
                f"weakly connected components heading into intersection. ",
                pattern.graph_pattern,
            )

        # First we just intersect the pattern graph.
        intersected_pattern = self.graph_pattern.intersection(
            pattern.graph_pattern,
            graph_logger=graph_logger,
            ontology=ontology,
            debug_callback=debug_callback,
            allowed_matches=allowed_matches,
            match_mode=match_mode,
            trim_after_match=trim_after_match,
        )

        if intersected_pattern:
            if self.graph_pattern.dynamic != intersected_pattern.dynamic:
                raise RuntimeError(
                    "Something is wrong - pattern dynamic-ness should not change "
                    "after intersection"
                )

            # If we get a successful intersection,
            # we then need to make sure we have the correct SyntaxSemanticsVariables.

            # It would be more intuitive to use self.template_variable_to_pattern_node,
            # but the pattern intersection code seems to prefer to return nodes
            # from the right-hand graph.
            template_variable_to_pattern_node = pattern.template_variable_to_pattern_node
            if graph_logger:
                graph_logger.log_graph(intersected_pattern, INFO, "Intersected pattern")
            for (_, object_wildcard) in template_variable_to_pattern_node.items():
                # we return none here since this means that the given template cannot be learned from since one of the slots has been pruned away
                if object_wildcard not in intersected_pattern:
                    return None

            return PerceptionGraphTemplate(
                graph_pattern=intersected_pattern,
                template_variable_to_pattern_node=template_variable_to_pattern_node,
            )
        else:
            return None

    @pattern_node_to_template_variable.default
    def _init_pattern_node_to_template_variable(
        self
    ) -> ImmutableDict[ObjectSemanticNodePerceptionPredicate, SyntaxSemanticsVariable]:
        """Derive the inverse mapping of ``template_variable_to_pattern_node``."""
        return immutabledict(
            {v: k for k, v in self.template_variable_to_pattern_node.items()}
        )

    def render_to_file(  # pragma: no cover
        self,
        graph_name: str,
        output_file: Path,
        *,
        match_correspondence_ids: Mapping[Any, str] = immutabledict(),
        robust=True,
    ):
        """Render the pattern graph to *output_file*, labeling slot nodes
        with their template variable names."""
        self.graph_pattern.render_to_file(
            graph_name,
            output_file,
            match_correspondence_ids=match_correspondence_ids,
            robust=robust,
            replace_node_labels=immutabledict(
                (pattern_node, template_variable.name)
                for (
                    pattern_node,
                    template_variable,
                ) in self.pattern_node_to_template_variable.items()
            ),
        )

    def copy_with_temporal_scopes(
        self, required_temporal_scopes: Union[TemporalScope, Iterable[TemporalScope]]
    ) -> "PerceptionGraphTemplate":
        r"""
        Produces a copy of this perception graph pattern
        where all edge predicates now require that the edge in the target graph being matched
        hold at all of the *required_temporal_scopes*.
        """
        return PerceptionGraphTemplate(
            graph_pattern=self.graph_pattern.copy_with_temporal_scopes(
                required_temporal_scopes
            ),
            template_variable_to_pattern_node=self.template_variable_to_pattern_node,
        )
class MdParserConfig:
    """Configuration options for the Markdown Parser.

    Note in the sphinx configuration these option names are prepended with ``myst_``

    NOTE(review): fields use ``attr.ib``, so the class is presumably
    decorated with ``@attr.s`` outside this excerpt — confirm. Field
    declaration order is significant to attrs and must not be changed.
    """

    # Output renderer; restricted to the three supported back-ends.
    renderer: str = attr.ib(
        default="sphinx", validator=in_(["sphinx", "html", "docutils"])
    )
    # When True, parse pure CommonMark only (no MyST extensions).
    commonmark_only: bool = attr.ib(default=False, validator=instance_of(bool))
    # Names of optional syntax extensions to enable (see check_extensions).
    enable_extensions: Iterable[str] = attr.ib(factory=lambda: ["dollarmath"])

    # Options for the dollarmath extension.
    dmath_allow_labels: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_space: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_allow_digits: bool = attr.ib(default=True, validator=instance_of(bool))
    dmath_double_inline: bool = attr.ib(default=False, validator=instance_of(bool))

    update_mathjax: bool = attr.ib(default=True, validator=instance_of(bool))

    # CSS classes controlling which elements MathJax processes.
    mathjax_classes: str = attr.ib(
        default="tex2jax_process|mathjax_process|math|output_area",
        validator=instance_of(str),
    )

    @enable_extensions.validator
    def check_extensions(self, attribute, value):
        """Validate that every requested extension name is recognised."""
        if not isinstance(value, Iterable):
            raise TypeError(f"myst_enable_extensions not iterable: {value}")
        diff = set(value).difference(
            [
                "dollarmath",
                "amsmath",
                "deflist",
                "html_admonition",
                "html_image",
                "colon_fence",
                "smartquotes",
                "replacements",
                "linkify",
                "substitution",
                "tasklist",
            ]
        )
        if diff:
            raise ValueError(f"myst_enable_extensions not recognised: {diff}")

    # Names of core syntax rules to disable.
    disable_syntax: Iterable[str] = attr.ib(
        factory=list,
        validator=deep_iterable(instance_of(str), instance_of((list, tuple))),
    )

    # see https://en.wikipedia.org/wiki/List_of_URI_schemes
    url_schemes: Optional[Iterable[str]] = attr.ib(
        default=cast(Optional[Iterable[str]], ("http", "https", "mailto", "ftp")),
        validator=optional(deep_iterable(instance_of(str), instance_of((list, tuple)))),
    )

    # Generate anchors for headings down to this level (None disables).
    heading_anchors: Optional[int] = attr.ib(
        default=None, validator=optional(in_([1, 2, 3, 4, 5, 6, 7]))
    )

    # Custom function mapping a heading's text to its anchor slug.
    heading_slug_func: Optional[Callable[[str], str]] = attr.ib(
        default=None, validator=optional(is_callable())
    )

    # HTML metadata entries; repr shows only the keys to keep output short.
    html_meta: Dict[str, str] = attr.ib(
        factory=dict,
        validator=deep_mapping(instance_of(str), instance_of(str), instance_of(dict)),
        repr=lambda v: str(list(v)),
    )

    footnote_transition: bool = attr.ib(default=True, validator=instance_of(bool))

    # Substitution definitions for the substitution extension.
    substitutions: Dict[str, Union[str, int, float]] = attr.ib(
        factory=dict,
        validator=deep_mapping(
            instance_of(str), instance_of((str, int, float)), instance_of(dict)
        ),
        repr=lambda v: str(list(v)),
    )

    # Opening/closing delimiter characters for substitutions.
    sub_delimiters: Tuple[str, str] = attr.ib(default=("{", "}"))

    words_per_minute: int = attr.ib(default=200, validator=instance_of(int))

    @sub_delimiters.validator
    def check_sub_delimiters(self, attribute, value):
        """Validate that the delimiters are exactly two single characters."""
        if (not isinstance(value, (tuple, list))) or len(value) != 2:
            raise TypeError(f"myst_sub_delimiters is not a tuple of length 2: {value}")
        for delim in value:
            if (not isinstance(delim, str)) or len(delim) != 1:
                raise TypeError(
                    f"myst_sub_delimiters does not contain strings of length 1: {value}"
                )

    @classmethod
    def get_fields(cls) -> Tuple[attr.Attribute, ...]:
        """Return the attrs field definitions for this config class."""
        return attr.fields(cls)

    def as_dict(self, dict_factory=dict) -> dict:
        """Return the configuration as a plain dictionary."""
        return attr.asdict(self, dict_factory=dict_factory)
Ejemplo n.º 29
0
class MdParserConfig:
    """Configuration options for the Markdown Parser.

    Note in the sphinx configuration these option names are prepended with ``myst_``

    NOTE(review): this is an older variant of ``MdParserConfig`` (it still
    carries the deprecated ``*_enable`` flags); if it shares a module with
    another definition of the same name, this one shadows the earlier —
    confirm which is intended. Fields use ``attr.ib``, so the class is
    presumably decorated with ``@attr.s`` outside this excerpt.
    """

    # Output renderer; restricted to the three supported back-ends.
    renderer: str = attr.ib(default="sphinx",
                            validator=in_(["sphinx", "html", "docutils"]))
    # When True, parse pure CommonMark only (no MyST extensions).
    commonmark_only: bool = attr.ib(default=False, validator=instance_of(bool))
    # Options for the dollarmath extension.
    dmath_allow_labels: bool = attr.ib(default=True,
                                       validator=instance_of(bool))
    dmath_allow_space: bool = attr.ib(default=True,
                                      validator=instance_of(bool))
    dmath_allow_digits: bool = attr.ib(default=True,
                                       validator=instance_of(bool))

    update_mathjax: bool = attr.ib(default=True, validator=instance_of(bool))

    # TODO remove deprecated _enable attributes after v0.13.0
    admonition_enable: bool = attr.ib(default=False,
                                      validator=instance_of(bool),
                                      repr=False)
    figure_enable: bool = attr.ib(default=False,
                                  validator=instance_of(bool),
                                  repr=False)
    dmath_enable: bool = attr.ib(default=False,
                                 validator=instance_of(bool),
                                 repr=False)
    amsmath_enable: bool = attr.ib(default=False,
                                   validator=instance_of(bool),
                                   repr=False)
    deflist_enable: bool = attr.ib(default=False,
                                   validator=instance_of(bool),
                                   repr=False)
    html_img_enable: bool = attr.ib(default=False,
                                    validator=instance_of(bool),
                                    repr=False)
    colon_fence_enable: bool = attr.ib(default=False,
                                       validator=instance_of(bool),
                                       repr=False)

    # Names of optional syntax extensions to enable (see check_extensions).
    enable_extensions: Iterable[str] = attr.ib(factory=lambda: ["dollarmath"])

    @enable_extensions.validator
    def check_extensions(self, attribute, value):
        """Validate that every requested extension name is recognised."""
        if not isinstance(value, Iterable):
            raise TypeError(f"myst_enable_extensions not iterable: {value}")
        diff = set(value).difference([
            "dollarmath",
            "amsmath",
            "deflist",
            "html_image",
            "colon_fence",
            "smartquotes",
            "replacements",
            "linkify",
            "substitution",
        ])
        if diff:
            raise ValueError(f"myst_enable_extensions not recognised: {diff}")

    # Names of core syntax rules to disable.
    disable_syntax: List[str] = attr.ib(
        factory=list,
        validator=deep_iterable(instance_of(str), instance_of((list, tuple))),
    )

    # see https://en.wikipedia.org/wiki/List_of_URI_schemes
    url_schemes: Optional[List[str]] = attr.ib(
        default=None,
        validator=optional(
            deep_iterable(instance_of(str), instance_of((list, tuple)))),
    )

    # Generate anchors for headings down to this level (None disables).
    heading_anchors: Optional[int] = attr.ib(default=None,
                                             validator=optional(
                                                 in_([1, 2, 3, 4, 5, 6, 7])))

    # Substitution definitions; repr shows only the keys to keep output short.
    substitutions: Dict[str, str] = attr.ib(
        factory=dict,
        validator=deep_mapping(instance_of(str), instance_of(
            (str, int, float)), instance_of(dict)),
        repr=lambda v: str(list(v)),
    )

    # Opening/closing delimiter characters for substitutions.
    sub_delimiters: Tuple[str, str] = attr.ib(default=("{", "}"))

    @sub_delimiters.validator
    def check_sub_delimiters(self, attribute, value):
        """Validate that the delimiters are exactly two single characters."""
        if (not isinstance(value, (tuple, list))) or len(value) != 2:
            raise TypeError(
                f"myst_sub_delimiters is not a tuple of length 2: {value}")
        for delim in value:
            if (not isinstance(delim, str)) or len(delim) != 1:
                raise TypeError(
                    f"myst_sub_delimiters does not contain strings of length 1: {value}"
                )

    def as_dict(self, dict_factory=dict) -> dict:
        """Return the configuration as a plain dictionary."""
        return attr.asdict(self, dict_factory=dict_factory)
Ejemplo n.º 30
0
class CloudFormationPhysicalResourceCache:
    """Cache for persistent information about CloudFormation stack resources.

    Cached values are the full ``describe_stack_resource`` response
    dictionaries, keyed by logical resource name.

    :param client: CloudFormation client used for lookups (boto3/botocore)
    :param stack_name: Name of target stack
    :param cache: Pre-populated cache mapping logical resource names to
        ``describe_stack_resource`` responses (optional)
    """

    # CloudFormation client; untyped because it is supplied by the caller.
    _client = attr.ib()
    _stack_name: str = attr.ib(validator=instance_of(str))
    # Fixed: cached values are full describe_stack_resource response dicts
    # (see physical_resource_name below) and the annotation is Dict[str, Dict],
    # but the validator previously required str values — rejecting every
    # valid pre-populated cache at init time.
    _cache: Dict[str, Dict] = attr.ib(
        default=attr.Factory(dict),
        validator=deep_mapping(key_validator=instance_of(str),
                               value_validator=instance_of(dict)),
    )

    def _describe_resource(self, logical_resource_name: str) -> Dict:
        """Describe the requested resource.

        :param logical_resource_name: Logical resource name of resource to describe
        :returns: result from ``describe_stack_resource`` call
        """
        return self._client.describe_stack_resource(
            StackName=self._stack_name,
            LogicalResourceId=logical_resource_name)

    def wait_until_resource_is_complete(self, logical_resource_name: str):
        """Wait until the specified resource reaches a COMPLETE status.

        :param logical_resource_name: Logical resource name of resource
        :raises Exception: if the resource enters any status other than
            create/update complete or in-progress
        """
        response = self.wait_until_resource_exists_in_stack(
            logical_resource_name)
        # The resource may exist before its status field is populated.
        if not response["StackResourceDetail"].get("ResourceStatus", ""):
            response = self._wait_until_field_exists(logical_resource_name,
                                                     "ResourceStatus")
        while True:
            status = response["StackResourceDetail"]["ResourceStatus"]
            _LOGGER.debug("Status of resource %s in stack %s is %s",
                          logical_resource_name, self._stack_name, status)

            if status in ("CREATE_COMPLETE", "UPDATE_COMPLETE"):
                break
            elif status in ("CREATE_IN_PROGRESS", "UPDATE_IN_PROGRESS"):
                # Use the module-wide polling interval (previously a
                # hard-coded 5) so all polling loops stay consistent.
                time.sleep(WAIT_PER_ATTEMPT)
                response = self._describe_resource(logical_resource_name)
            else:
                raise Exception(
                    f'Resource creation failed. Resource "{logical_resource_name}" status: "{status}"'
                )

    def wait_until_resource_exists_in_stack(
            self, logical_resource_name: str) -> Dict:
        """Wait until the specified resource exists.

        :param logical_resource_name: Logical resource name of resource
        :raises ClientError: if the resource still does not exist after
            ``MAX_RESOURCE_ATTEMPTS`` attempts, or on any other client error
        """
        resource_attempts = 1
        while True:
            _LOGGER.debug(
                "Waiting for creation of resource %s in stack %s to start. Attempt %d of %d",
                logical_resource_name,
                self._stack_name,
                resource_attempts,
                MAX_RESOURCE_ATTEMPTS,
            )
            try:
                return self._describe_resource(logical_resource_name)
            except ClientError as error:
                _LOGGER.debug('Encountered botocore ClientError: "%s"',
                              error.response["Error"]["Message"])
                # Only the "resource does not exist yet" error is retriable;
                # any other client error is re-raised immediately.
                if (error.response["Error"]["Message"] ==
                        f"Resource {logical_resource_name} does not exist for stack {self._stack_name}"
                    ):
                    resource_attempts += 1

                    if resource_attempts > MAX_RESOURCE_ATTEMPTS:
                        raise
                else:
                    raise

            time.sleep(WAIT_PER_ATTEMPT)

    def _wait_until_field_exists(self, logical_resource_name: str,
                                 field_name: str) -> Dict:
        """Keep trying to describe a resource until it has the requested field.

        Waits ``WAIT_PER_ATTEMPT`` seconds between attempts.

        :param logical_resource_name: Logical resource name of resource
        :param field_name: Field in resource details to wait for
        :returns: results from ``describe_stack_resource`` call
        """
        resource_attempts = 1
        response = self.wait_until_resource_exists_in_stack(
            logical_resource_name)
        while not response.get("StackResourceDetail", {}).get(field_name, ""):
            time.sleep(WAIT_PER_ATTEMPT)

            _LOGGER.debug(
                "Waiting for resource %s in stack %s to have a value for field %s. Attempt %d of %d",
                logical_resource_name,
                self._stack_name,
                field_name,
                resource_attempts,
                MAX_RESOURCE_ATTEMPTS,
            )
            # Fixed: the attempt counter was never incremented, so the debug
            # message always reported "Attempt 1".
            # NOTE(review): this loop is still unbounded —
            # MAX_RESOURCE_ATTEMPTS is logged but not enforced here; confirm
            # whether a cap is intended.
            resource_attempts += 1
            response = self._describe_resource(logical_resource_name)

        return response

    def physical_resource_name(self, logical_resource_name: str) -> str:
        """Find the physical resource name given its logical resource name.

        If the resource does not exist yet, wait until it does. Results are
        memoized in ``self._cache`` as full ``describe_stack_resource``
        responses.

        :param logical_resource_name: Logical resource name of resource
        """
        try:
            response = self._cache[logical_resource_name]  # attrs confuses pylint: disable=unsubscriptable-object
        except KeyError:
            response = self._wait_until_field_exists(
                logical_resource_name=logical_resource_name,
                field_name="PhysicalResourceId")
            self._cache[  # attrs confuses pylint: disable=unsupported-assignment-operation
                logical_resource_name] = response

        return response["StackResourceDetail"]["PhysicalResourceId"]