async def test_asyncstate_with_list_of_valid_states():
    """Check that AsyncState rejects values outside its list of valid states."""
    allowed = [1, 2, 3]
    async_state = AsyncState(1, allowed)

    # A value from the allowed list is accepted.
    async_state.set(2)
    assert async_state.get() == 2

    # A value outside the allowed list raises and leaves the state untouched.
    with pytest.raises(ValueError):
        async_state.set("anything")
    assert async_state.get() == 2
async def test_async_state_transit():
    """Check the AsyncState.transit context manager on success and on failure."""
    async_state = AsyncState()

    # Success path: `initial` inside the block, `success` once it exits cleanly.
    async_state.set(None)
    with async_state.transit(initial=1, success=2, fail=3):
        assert async_state.get() == 1
    assert async_state.get() == 2

    # Failure path: an exception raised inside the block ends in `fail`.
    async_state.set(None)
    with suppress(ValueError), async_state.transit(initial=1, success=2, fail=3):
        assert async_state.get() == 1
        raise ValueError()
    assert async_state.get() == 3
async def test_asyncstate_callback():
    """Check that a failing callback does not block later callbacks or the set."""
    async_state = AsyncState()
    called = False

    def callback_err(state):
        # Registered first; deliberately fails.
        raise Exception("expected")

    def callback(state):
        # Registered second; records that it was still invoked.
        nonlocal called
        called = True

    for handler in (callback_err, callback):
        async_state.add_callback(handler)

    async_state.set(2)

    assert async_state.get() == 2
    assert called
async def test_async_state():
    """Exercise set/get, the ``state`` property and ``wait`` of AsyncState."""
    event_loop = asyncio.get_event_loop()
    async_state = AsyncState()

    # Round trip through the set()/get() methods.
    expected = 1
    async_state.set(expected)
    assert async_state.get() == expected

    # Round trip through the ``state`` property.
    expected = 3
    async_state.state = 3
    assert async_state.state == expected

    # wait() completes once the scheduled set() has run.
    event_loop.call_soon(async_state.set, 2)
    await async_state.wait(2)

    # wait() on a state that is already current.
    await async_state.wait(2)
class Connection(Component, ABC):
    """Abstract definition of a connection."""

    # Concrete connection classes set this; it is compared with the public id
    # taken from the configuration in __init__.
    connection_id = None  # type: PublicId

    def __init__(
        self,
        configuration: ConnectionConfig,
        data_dir: str,
        identity: Optional[Identity] = None,
        crypto_store: Optional[CryptoStore] = None,
        restricted_to_protocols: Optional[Set[PublicId]] = None,
        excluded_protocols: Optional[Set[PublicId]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the connection.

        The configuration is mandatory and its public id must match the
        `connection_id` declared on the class.

        :param configuration: the connection configuration.
        :param data_dir: directory where to put local files.
        :param identity: the identity object held by the agent.
        :param crypto_store: the crypto store for encrypted communication.
        :param restricted_to_protocols: the set of protocols ids of the only supported protocols for this connection.
        :param excluded_protocols: the set of protocols ids that we want to exclude for this connection.
        :param kwargs: extra keyword arguments forwarded to the parent constructor.
        """
        enforce(configuration is not None, "The configuration must be provided.")
        super().__init__(configuration, **kwargs)
        enforce(
            super().public_id == self.connection_id,
            "Connection ids in configuration and class not matching.",
        )
        self._state = AsyncState(ConnectionStates.disconnected)
        self._identity = identity
        self._crypto_store = crypto_store
        self._data_dir = data_dir

        # Fallback values, used by the properties only when no configuration is set.
        self._restricted_to_protocols = (
            restricted_to_protocols if restricted_to_protocols is not None else set()
        )
        self._excluded_protocols = (
            excluded_protocols if excluded_protocols is not None else set()
        )

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """
        Get the event loop.

        :return: the current event loop; it is enforced to be running.
        """
        loop = asyncio.get_event_loop()
        enforce(loop.is_running(), "Event loop is not running.")
        return loop

    def _ensure_connected(self) -> None:  # pragma: nocover
        """Raise ConnectionError if the connection is not connected."""
        if not self.is_connected:
            raise ConnectionError("Connection is not connected! Connect first!")

    @staticmethod
    def _ensure_valid_envelope_for_external_comms(envelope: "Envelope") -> None:
        """
        Ensure the envelope sender and to are valid addresses for agent-to-agent communication.

        :param envelope: the envelope
        """
        enforce(
            not envelope.is_sender_public_id,
            f"Sender field of envelope is public id, needs to be address. Found={envelope.sender}",
        )
        enforce(
            not envelope.is_to_public_id,
            f"To field of envelope is public id, needs to be address. Found={envelope.to}",
        )

    @contextmanager
    def _connect_context(self) -> Generator:
        """
        Drive the connection state while the body of a connect method runs.

        The state transition uses `connecting` as initial state, `connected`
        as success state and `disconnected` as fail state.
        """
        with self._state.transit(
            initial=ConnectionStates.connecting,
            success=ConnectionStates.connected,
            fail=ConnectionStates.disconnected,
        ):
            yield

    @property
    def address(self) -> "Address":  # pragma: nocover
        """Get the address; raise ValueError if no identity was provided."""
        if self._identity is None:
            raise ValueError(
                "You must provide the identity in order to retrieve the address."
            )
        return self._identity.address

    @property
    def crypto_store(self) -> CryptoStore:  # pragma: nocover
        """Get the crypto store; raise ValueError if it is not available."""
        if self._crypto_store is None:
            raise ValueError("CryptoStore not available.")
        return self._crypto_store

    @property
    def has_crypto_store(self) -> bool:  # pragma: nocover
        """Check if the connection has the crypto store."""
        return self._crypto_store is not None

    @property
    def data_dir(self) -> str:  # pragma: nocover
        """Get the data directory."""
        return self._data_dir

    @property
    def component_type(self) -> ComponentType:  # pragma: nocover
        """Get the component type."""
        return ComponentType.CONNECTION

    @property
    def configuration(self) -> ConnectionConfig:
        """Get the connection configuration; raise ValueError if it is not set."""
        if self._configuration is None:  # pragma: nocover
            raise ValueError("Configuration not set.")
        return cast(ConnectionConfig, super().configuration)

    @property
    def restricted_to_protocols(self) -> Set[PublicId]:  # pragma: nocover
        """Get the ids of the protocols this connection is restricted to."""
        if self._configuration is None:
            return self._restricted_to_protocols
        return self.configuration.restricted_to_protocols

    @property
    def excluded_protocols(self) -> Set[PublicId]:  # pragma: nocover
        """Get the ids of the excluded protocols for this connection."""
        if self._configuration is None:
            return self._excluded_protocols
        return self.configuration.excluded_protocols

    @property
    def state(self) -> ConnectionStates:
        """Get the connection status."""
        return self._state.get()

    @state.setter
    def state(self, value: ConnectionStates) -> None:
        """
        Set the connection status.

        :param value: the new state.
        :raises ValueError: if value is not a ConnectionStates instance.
        """
        if not isinstance(value, ConnectionStates):
            raise ValueError(f"Incorrect state: `{value}`")
        self._state.set(value)

    @abstractmethod
    async def connect(self) -> None:
        """Set up the connection."""

    @abstractmethod
    async def disconnect(self) -> None:
        """Tear down the connection."""

    @abstractmethod
    async def send(self, envelope: "Envelope") -> None:
        """
        Send an envelope.

        :param envelope: the envelope to send.
        :return: None
        """

    @abstractmethod
    async def receive(self, *args: Any, **kwargs: Any) -> Optional["Envelope"]:
        """
        Receive an envelope.

        :return: the received envelope, or None if an error occurred.
        """

    @classmethod
    def from_dir(
        cls,
        directory: str,
        identity: Identity,
        crypto_store: CryptoStore,
        data_dir: str,
        **kwargs: Any,
    ) -> "Connection":
        """
        Load the connection from a directory.

        :param directory: the directory to the connection package.
        :param identity: the identity object.
        :param crypto_store: object to access the connection crypto objects.
        :param data_dir: the assets directory.
        :param kwargs: extra keyword arguments forwarded to `from_config`.
        :return: the connection object.
        """
        package_dir = Path(directory)
        configuration = cast(
            ConnectionConfig,
            load_component_configuration(ComponentType.CONNECTION, package_dir),
        )
        configuration.directory = package_dir
        # Dispatch through `cls` so that a subclass overriding `from_config` is honoured.
        return cls.from_config(
            configuration, identity, crypto_store, data_dir, **kwargs
        )

    @classmethod
    def from_config(
        cls,
        configuration: ConnectionConfig,
        identity: Identity,
        crypto_store: CryptoStore,
        data_dir: str,
        **kwargs: Any,
    ) -> "Connection":
        """
        Load a connection from a configuration.

        :param configuration: the connection configuration.
        :param identity: the identity object.
        :param crypto_store: object to access the connection crypto objects.
        :param data_dir: the directory of the AEA project data.
        :param kwargs: extra keyword arguments forwarded to the connection constructor.
        :return: an instance of the concrete connection class.
        :raises AEAComponentLoadException: if `connection.py` or the configured class is missing.
        :raises AEAInstantiationException: if the connection constructor fails.
        """
        configuration = cast(ConnectionConfig, configuration)
        directory = cast(Path, configuration.directory)
        load_aea_package(configuration)

        connection_module_path = directory / "connection.py"
        if not (connection_module_path.exists() and connection_module_path.is_file()):
            raise AEAComponentLoadException(
                "Connection module '{}' not found.".format(connection_module_path)
            )
        connection_module = load_module("connection_module", connection_module_path)

        # Exact-name lookup among the classes visible in the module; the class
        # name is a plain identifier, not a pattern.
        connection_class_name = cast(str, configuration.class_name)
        name_to_class = dict(inspect.getmembers(connection_module, inspect.isclass))

        logger = get_logger(__name__, identity.name)
        logger.debug("Processing connection {}".format(connection_class_name))

        connection_class = name_to_class.get(connection_class_name)
        if connection_class is None:
            raise AEAComponentLoadException(
                "Connection class '{}' not found.".format(connection_class_name)
            )

        try:
            connection = connection_class(
                configuration=configuration,
                data_dir=data_dir,
                identity=identity,
                crypto_store=crypto_store,
                **kwargs,
            )
        except Exception as e:  # pragma: nocover # pylint: disable=broad-except
            e_str = parse_exception(e)
            # Chain the original exception so its traceback stays attached as the cause.
            raise AEAInstantiationException(
                f"An error occured during instantiation of connection {configuration.public_id}/{configuration.class_name}:\n{e_str}"
            ) from e
        return connection

    @property
    def is_connected(self) -> bool:  # pragma: nocover
        """Return is connected state."""
        return self.state == ConnectionStates.connected

    @property
    def is_connecting(self) -> bool:  # pragma: nocover
        """Return is connecting state."""
        return self.state == ConnectionStates.connecting

    @property
    def is_disconnected(self) -> bool:  # pragma: nocover
        """Return is disconnected state."""
        return self.state == ConnectionStates.disconnected
class Storage(Runnable):
    """Generic storage."""

    def __init__(
        self,
        storage_uri: str,
        loop: asyncio.AbstractEventLoop = None,
        threaded: bool = False,
    ) -> None:
        """
        Init storage.

        :param storage_uri: configuration string for storage.
        :param loop: asyncio event loop to use.
        :param threaded: bool. start in thread if True.
        :return: None
        :raises ValueError: if the URI scheme does not match a supported backend.
        """
        super().__init__(loop=loop, threaded=threaded)
        self._storage_uri = storage_uri
        self._backend: AbstractStorageBackend = self._get_backend_instance(
            storage_uri
        )
        # Plain flag read by `is_connected`; the AsyncState mirrors it so that
        # coroutines can await the connected state in `wait_connected`.
        self._is_connected = False
        self._connected_state = AsyncState(False)

    async def wait_connected(self) -> None:
        """Wait generic storage is connected."""
        await self._connected_state.wait(True)

    @property
    def is_connected(self) -> bool:
        """Get running state of the storage."""
        return self._is_connected

    async def run(self) -> None:
        """
        Connect storage and stay connected until the task is interrupted.

        The backend is disconnected on exit, and both connection indicators
        are reset even if the disconnect itself fails.
        """
        await self._backend.connect()
        self._is_connected = True
        self._connected_state.set(True)
        try:
            # Idle until an exception (e.g. cancellation) is raised into the sleep.
            while True:
                await asyncio.sleep(1)
        finally:
            try:
                await self._backend.disconnect()
            finally:
                # Keep the flag and the awaitable state in agreement, so that
                # `wait_connected` does not report a stale connection.
                self._is_connected = False
                self._connected_state.set(False)

    @classmethod
    def _get_backend_instance(cls, uri: str) -> AbstractStorageBackend:
        """
        Construct backend instance.

        :param uri: storage URI; its scheme selects the backend class.
        :return: backend instance created with the URI.
        :raises ValueError: if the scheme is not a key of BACKENDS.
        """
        backend_name = urlparse(uri).scheme
        backend_class = BACKENDS.get(backend_name, None)
        if backend_class is None:
            raise ValueError(
                f"Backend `{backend_name}` is not supported. Supported are {', '.join(BACKENDS.keys())} "
            )
        return backend_class(uri)

    async def get_collection(self, collection_name: str) -> AsyncCollection:
        """
        Get async collection.

        :param collection_name: name of the collection; ensured on the backend first.
        :return: async collection bound to this storage's backend.
        """
        await self._backend.ensure_collection(collection_name)
        return AsyncCollection(
            collection_name=collection_name, storage_backend=self._backend
        )

    def get_sync_collection(self, collection_name: str) -> SyncCollection:
        """
        Get sync collection.

        :param collection_name: name of the collection.
        :return: sync collection built from the `get_collection` coroutine and the storage loop.
        :raises ValueError: if the storage has no event loop.
        """
        if not self._loop:  # pragma: nocover
            raise ValueError("Storage not started!")
        # The un-awaited coroutine is handed over as-is, together with the loop.
        # NOTE(review): presumably SyncCollection runs it on that loop - confirm.
        return SyncCollection(self.get_collection(collection_name), self._loop)

    def __repr__(self) -> str:
        """Get string representation of the storage."""
        return f"[GenericStorage({self._storage_uri}){'Connected' if self.is_connected else 'Not connected'}]"