示例#1
0
async def test_asyncstate_with_list_of_valid_states():
    """Check that AsyncState only accepts values from its list of valid states."""
    allowed = [1, 2, 3]
    async_state = AsyncState(1, allowed)

    # A value taken from the allowed list is accepted.
    async_state.set(2)
    assert async_state.get() == 2

    # A value outside the allowed list is rejected with ValueError ...
    with pytest.raises(ValueError):
        async_state.set("anything")

    # ... and the previously stored value is still in place.
    assert async_state.get() == 2
示例#2
0
async def test_async_state_transit():
    """Check the AsyncState.transit context manager on success and on failure."""
    transitions = dict(initial=1, success=2, fail=3)
    async_state = AsyncState()

    # Clean exit: `initial` is visible inside the block, `success` afterwards.
    async_state.set(None)
    with async_state.transit(**transitions):
        assert async_state.get() == 1
    assert async_state.get() == 2

    # Exception inside the block: the state is expected to end up as `fail`.
    async_state.set(None)
    with suppress(ValueError), async_state.transit(**transitions):
        assert async_state.get() == 1
        raise ValueError()
    assert async_state.get() == 3
示例#3
0
async def test_asyncstate_callback():
    """Check that a raising callback does not stop set() or the callbacks after it."""
    invocations = []

    def callback_err(state):
        raise Exception("expected")

    def callback(state):
        # Record that the well-behaved callback has been reached.
        invocations.append(True)

    async_state = AsyncState()
    # Register the failing callback first so it runs before the recording one.
    for handler in (callback_err, callback):
        async_state.add_callback(handler)

    async_state.set(2)
    assert async_state.get() == 2
    assert invocations
示例#4
0
async def test_async_state():
    """Exercise set/get, the `state` property and wait() of AsyncState."""
    event_loop = asyncio.get_event_loop()
    async_state = AsyncState()

    # Plain set/get round trip.
    expected = 1
    async_state.set(expected)
    assert async_state.get() == expected

    # Same round trip, this time through the `state` property.
    expected = 3
    async_state.state = expected
    assert async_state.state == expected

    # Schedule a set() on the loop and wait for that value to show up.
    event_loop.call_soon(async_state.set, 2)
    await async_state.wait(2)

    # Waiting for a value that is already set must return as well.
    await async_state.wait(2)
示例#5
0
class Connection(Component, ABC):
    """
    Abstract definition of a connection.

    Concrete connections set the class attribute `connection_id` and implement
    `connect`, `disconnect`, `send` and `receive`.
    """

    # Public id of the concrete connection; compared against the id found in
    # the configuration when an instance is created.
    connection_id = None  # type: PublicId

    def __init__(
        self,
        configuration: ConnectionConfig,
        data_dir: str,
        identity: Optional[Identity] = None,
        crypto_store: Optional[CryptoStore] = None,
        restricted_to_protocols: Optional[Set[PublicId]] = None,
        excluded_protocols: Optional[Set[PublicId]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the connection.

        The configuration is mandatory, and its public id must match the
        `connection_id` declared on the class; both conditions are checked
        with `enforce`. The connection starts in the `disconnected` state.

        :param configuration: the connection configuration.
        :param data_dir: directory where to put local files.
        :param identity: the identity object held by the agent.
        :param crypto_store: the crypto store for encrypted communication.
        :param restricted_to_protocols: the set of protocols ids of the only supported protocols for this connection.
        :param excluded_protocols: the set of protocols ids that we want to exclude for this connection.
        :param kwargs: extra keyword arguments forwarded to the parent constructor.
        """
        enforce(configuration is not None, "The configuration must be provided.")
        super().__init__(configuration, **kwargs)
        enforce(
            super().public_id == self.connection_id,
            "Connection ids in configuration and class not matching.",
        )
        self._state = AsyncState(ConnectionStates.disconnected)

        self._identity = identity
        self._crypto_store = crypto_store
        self._data_dir = data_dir

        # Fallback values; the matching properties only return them when no
        # configuration is set on the instance.
        self._restricted_to_protocols = (
            restricted_to_protocols if restricted_to_protocols is not None else set()
        )
        self._excluded_protocols = (
            excluded_protocols if excluded_protocols is not None else set()
        )

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Get the event loop; it is required to be running."""
        # Fetch the loop once so the check and the returned object agree.
        loop = asyncio.get_event_loop()
        enforce(loop.is_running(), "Event loop is not running.")
        return loop

    def _ensure_connected(self) -> None:  # pragma: nocover
        """
        Raise exception if connection is not connected.

        :raises ConnectionError: if the current state is not `connected`.
        """
        if not self.is_connected:
            raise ConnectionError("Connection is not connected! Connect first!")

    @staticmethod
    def _ensure_valid_envelope_for_external_comms(envelope: "Envelope") -> None:
        """
        Ensure the envelope sender and to are valid addresses for agent-to-agent communication.

        Both fields are checked with `enforce` not to be public ids.

        :param envelope: the envelope
        """
        enforce(
            not envelope.is_sender_public_id,
            f"Sender field of envelope is public id, needs to be address. Found={envelope.sender}",
        )
        enforce(
            not envelope.is_to_public_id,
            f"To field of envelope is public id, needs to be address. Found={envelope.to}",
        )

    @contextmanager
    def _connect_context(self) -> Generator:
        """
        Drive the connection state while a connect routine runs.

        Wraps the body in a state transit: `connecting` as initial value,
        `connected` on success and `disconnected` on failure.
        """
        with self._state.transit(
            initial=ConnectionStates.connecting,
            success=ConnectionStates.connected,
            fail=ConnectionStates.disconnected,
        ):
            yield

    @property
    def address(self) -> "Address":  # pragma: nocover
        """
        Get the address.

        :raises ValueError: if no identity was provided.
        """
        if self._identity is None:
            raise ValueError(
                "You must provide the identity in order to retrieve the address."
            )
        return self._identity.address

    @property
    def crypto_store(self) -> CryptoStore:  # pragma: nocover
        """
        Get the crypto store.

        :raises ValueError: if no crypto store was provided.
        """
        if self._crypto_store is None:
            raise ValueError("CryptoStore not available.")
        return self._crypto_store

    @property
    def has_crypto_store(self) -> bool:  # pragma: nocover
        """Check if the connection has the crypto store."""
        return self._crypto_store is not None

    @property
    def data_dir(self) -> str:  # pragma: nocover
        """Get the data directory."""
        return self._data_dir

    @property
    def component_type(self) -> ComponentType:  # pragma: nocover
        """Get the component type."""
        return ComponentType.CONNECTION

    @property
    def configuration(self) -> ConnectionConfig:
        """
        Get the connection configuration.

        :raises ValueError: if the configuration is not set.
        """
        if self._configuration is None:  # pragma: nocover
            raise ValueError("Configuration not set.")
        return cast(ConnectionConfig, super().configuration)

    @property
    def restricted_to_protocols(self) -> Set[PublicId]:  # pragma: nocover
        """Get the ids of the protocols this connection is restricted to."""
        if self._configuration is None:
            return self._restricted_to_protocols
        return self.configuration.restricted_to_protocols

    @property
    def excluded_protocols(self) -> Set[PublicId]:  # pragma: nocover
        """Get the ids of the excluded protocols for this connection."""
        if self._configuration is None:
            return self._excluded_protocols
        return self.configuration.excluded_protocols

    @property
    def state(self) -> ConnectionStates:
        """Get the connection status."""
        return self._state.get()

    @state.setter
    def state(self, value: ConnectionStates) -> None:
        """
        Set the connection status.

        :raises ValueError: if `value` is not a `ConnectionStates` member.
        """
        if not isinstance(value, ConnectionStates):
            raise ValueError(f"Incorrect state: `{value}`")
        self._state.set(value)

    @abstractmethod
    async def connect(self) -> None:
        """Set up the connection."""

    @abstractmethod
    async def disconnect(self) -> None:
        """Tear down the connection."""

    @abstractmethod
    async def send(self, envelope: "Envelope") -> None:
        """
        Send an envelope.

        :param envelope: the envelope to send.
        :return: None
        """

    @abstractmethod
    async def receive(self, *args: Any, **kwargs: Any) -> Optional["Envelope"]:
        """
        Receive an envelope.

        :return: the received envelope, or None if an error occurred.
        """

    @classmethod
    def from_dir(
        cls,
        directory: str,
        identity: Identity,
        crypto_store: CryptoStore,
        data_dir: str,
        **kwargs: Any,
    ) -> "Connection":
        """
        Load the connection from a directory.

        The component configuration is loaded from `directory`, and the rest
        of the work is delegated to `from_config`.

        :param directory: the directory to the connection package.
        :param identity: the identity object.
        :param crypto_store: object to access the connection crypto objects.
        :param data_dir: the assets directory.
        :param kwargs: extra keyword arguments forwarded to `from_config`.
        :return: the connection object.
        """
        package_dir = Path(directory)
        configuration = cast(
            ConnectionConfig,
            load_component_configuration(ComponentType.CONNECTION, package_dir),
        )
        configuration.directory = package_dir
        # Dispatch through `cls` so a subclass overriding `from_config` is honoured.
        return cls.from_config(
            configuration, identity, crypto_store, data_dir, **kwargs
        )

    @classmethod
    def from_config(
        cls,
        configuration: ConnectionConfig,
        identity: Identity,
        crypto_store: CryptoStore,
        data_dir: str,
        **kwargs: Any,
    ) -> "Connection":
        """
        Load a connection from a configuration.

        Loads `connection.py` from the configuration directory, looks up the
        class named by `configuration.class_name` in it and instantiates it.

        :param configuration: the connection configuration.
        :param identity: the identity object.
        :param crypto_store: object to access the connection crypto objects.
        :param data_dir: the directory of the AEA project data.
        :param kwargs: extra keyword arguments forwarded to the connection class constructor.
        :return: an instance of the concrete connection class.
        :raises AEAComponentLoadException: if the module file or the class is not found.
        :raises AEAInstantiationException: if the connection class constructor raises.
        """
        configuration = cast(ConnectionConfig, configuration)
        directory = cast(Path, configuration.directory)
        load_aea_package(configuration)
        connection_module_path = directory / "connection.py"
        # `is_file()` is already False for a missing path.
        if not connection_module_path.is_file():
            raise AEAComponentLoadException(
                "Connection module '{}' not found.".format(connection_module_path)
            )
        connection_module = load_module("connection_module", connection_module_path)
        connection_class_name = cast(str, configuration.class_name)
        # Exact-name lookup; the class name is never interpreted as a regex.
        name_to_class = dict(inspect.getmembers(connection_module, inspect.isclass))
        logger = get_logger(__name__, identity.name)
        logger.debug("Processing connection {}".format(connection_class_name))
        connection_class = name_to_class.get(connection_class_name)
        if connection_class is None:
            raise AEAComponentLoadException(
                "Connection class '{}' not found.".format(connection_class_name)
            )
        try:
            connection = connection_class(
                configuration=configuration,
                data_dir=data_dir,
                identity=identity,
                crypto_store=crypto_store,
                **kwargs,
            )
        except Exception as e:  # pragma: nocover # pylint: disable=broad-except
            e_str = parse_exception(e)
            # Chain explicitly so the original failure stays attached as the cause.
            raise AEAInstantiationException(
                f"An error occured during instantiation of connection {configuration.public_id}/{configuration.class_name}:\n{e_str}"
            ) from e
        return connection

    @property
    def is_connected(self) -> bool:  # pragma: nocover
        """Return is connected state."""
        return self.state == ConnectionStates.connected

    @property
    def is_connecting(self) -> bool:  # pragma: nocover
        """Return is connecting state."""
        return self.state == ConnectionStates.connecting

    @property
    def is_disconnected(self) -> bool:  # pragma: nocover
        """Return is disconnected state."""
        return self.state == ConnectionStates.disconnected
示例#6
0
class Storage(Runnable):
    """
    Generic storage.

    Picks a storage backend from the scheme of the storage URI and keeps it
    connected for as long as `run` is executing.
    """

    def __init__(
        self,
        storage_uri: str,
        loop: asyncio.AbstractEventLoop = None,
        threaded: bool = False,
    ) -> None:
        """
        Init storage.

        :param storage_uri: configuration string for storage.
        :param loop: asyncio event loop to use.
        :param threaded: bool. start in thread if True.

        :return: None
        :raises ValueError: if the URI scheme does not name a supported backend.
        """
        super().__init__(loop=loop, threaded=threaded)
        self._storage_uri = storage_uri
        self._backend: AbstractStorageBackend = self._get_backend_instance(
            storage_uri)
        # Plain flag read by `is_connected`; the AsyncState mirrors it so
        # coroutines can wait for the connection in `wait_connected`.
        self._is_connected = False
        self._connected_state = AsyncState(False)

    async def wait_connected(self) -> None:
        """Wait generic storage is connected."""
        await self._connected_state.wait(True)

    @property
    def is_connected(self) -> bool:
        """Get running state of the storage."""
        return self._is_connected

    async def run(self) -> None:
        """
        Connect storage and stay connected until the coroutine is stopped.

        The backend is disconnected on the way out, and both connection
        indicators are reset.
        """
        await self._backend.connect()
        self._is_connected = True
        self._connected_state.set(True)
        try:
            # Idle loop: keeps the coroutine alive until it is interrupted.
            while True:
                await asyncio.sleep(1)
        finally:
            try:
                await self._backend.disconnect()
            finally:
                # Reset both indicators together, even if disconnect raised,
                # so `wait_connected` does not report a stale connection.
                self._is_connected = False
                self._connected_state.set(False)

    @classmethod
    def _get_backend_instance(cls, uri: str) -> AbstractStorageBackend:
        """
        Construct backend instance.

        :param uri: storage URI; its scheme is the key looked up in BACKENDS.
        :return: a backend instance created with the full URI.
        :raises ValueError: if no backend is registered for the scheme.
        """
        backend_name = urlparse(uri).scheme
        backend_class = BACKENDS.get(backend_name)
        if backend_class is None:
            raise ValueError(
                f"Backend `{backend_name}` is not supported. Supported are {', '.join(BACKENDS.keys())} "
            )
        return backend_class(uri)

    async def get_collection(self, collection_name: str) -> AsyncCollection:
        """
        Get async collection.

        Asks the backend to ensure the collection before wrapping it.

        :param collection_name: name of the collection.
        :return: async collection bound to this storage's backend.
        """
        await self._backend.ensure_collection(collection_name)
        return AsyncCollection(collection_name=collection_name,
                               storage_backend=self._backend)

    def get_sync_collection(self, collection_name: str) -> SyncCollection:
        """
        Get sync collection.

        :param collection_name: name of the collection.
        :return: sync collection built from the storage loop.
        :raises ValueError: if the storage has no event loop yet.
        """
        if not self._loop:  # pragma: nocover
            raise ValueError("Storage not started!")
        # The un-awaited coroutine is handed over together with the loop;
        # presumably SyncCollection runs it on that loop - confirm there.
        return SyncCollection(self.get_collection(collection_name), self._loop)

    def __repr__(self) -> str:
        """Get string representation of the storage."""
        return f"[GenericStorage({self._storage_uri}){'Connected' if self.is_connected else 'Not connected'}]"