class Barrier(object):
    """Reusable rendezvous between a waiting side and a signalling side.

    One party blocks in :meth:`wait` (alias :meth:`wait_answer`) until the
    other calls :meth:`go` (alias :meth:`answer`). The event is cleared after
    every wait, so the same barrier can be used for the next round.

    NOTE(review): ``Event`` comes from an import outside this block; the
    use of ``ready()`` elsewhere in the file suggests ``gevent.event.Event``
    — confirm. Only ``wait``/``set``/``clear``/``is_set`` are used here,
    which gevent's and threading's ``Event`` both provide.
    """

    def __init__(self, warning=False):
        # Aliases for the two halves of the handshake.
        self.wait_answer = self.wait
        self.answer = self.go
        # time.time() of the most recent go(); 0 until go() is first called.
        self.go_timestamp = 0
        self.tolerance_before_wait = 0.01  # seconds
        # Honour the constructor argument (it used to be hard-coded to False).
        self.warning = warning
        self.barrier_event = Event()

    def wait(self, timeout=None):
        """Block until :meth:`go` is called or ``timeout`` seconds elapse.

        :param timeout: seconds to wait, or ``None`` to wait indefinitely.
        :returns: the result of ``Event.wait`` (truthy if released, falsy on
            timeout).

        The event is cleared afterwards in both cases, so the next call
        blocks again.
        """
        success = self.barrier_event.wait(timeout)
        self.barrier_event.clear()
        return success

    def go(self):
        """Record the release time and wake the waiting side."""
        self.go_timestamp = time.time()
        self.barrier_event.set()

    def start(self):
        """Store the current time in ``self.start_time``."""
        self.start_time = time.time()

    def restart(self):
        """Same as :meth:`start`."""
        self.start()

    def is_waiting(self):
        """Return True while go() has not been signalled since the last wait()."""
        return not self.barrier_event.is_set()
def test_bridge():
    """Two bridged tasks run to completion and are then unregistered."""
    receiver_flag, sender_flag = Event(), Event()
    receiver = TaskManager.spawn(Receiver(receiver_flag))
    sender = TaskManager.spawn(Sender(sender_flag))
    with receiver.bridge(sender):
        # Both tasks are registered while the bridge is open.
        assert TaskManager.count() == 2
        # Block until each task has finished its work...
        for task in (receiver, sender):
            task.wait()
        # ...and check that each one set its completion flag.
        assert receiver_flag.ready()
        assert sender_flag.ready()
    assert TaskManager.count() == 0
class _TaskIOBridge(object):
    """Bidirectionally forward messages between two tasks.

    One greenlet pumps ``intask.output`` into ``outtask.input``; a second one
    pumps ``outtask.output`` into ``intask.input``. The bridge counts as
    closed as soon as either pump stops or :meth:`close` is called. It can be
    used as a context manager; leaving the ``with`` block closes it.
    """

    def __init__(self, intask, outtask):
        self._intask = intask
        self._outtask = outtask
        # Create the closed flag before spawning the forwarders: they set it
        # from their ``finally`` clauses, and close()/__del__ read it.
        self._closed = Event()
        self._in2out = gevent.spawn(self._forward_in2out)
        self._out2in = gevent.spawn(self._forward_out2in)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __del__(self):
        self.close()

    def _forward_in2out(self):
        """Pump every message from intask's output into outtask's input."""
        try:
            for msg in self._intask.output.watch():
                self._outtask.input.send(msg)
        finally:
            self._closed.set()

    def _forward_out2in(self):
        """Pump every message from outtask's output into intask's input."""
        try:
            for msg in self._outtask.output.watch():
                self._intask.input.send(msg)
        finally:
            self._closed.set()

    @property
    def closed(self):
        """True once either forwarder has stopped or close() was called."""
        return self._closed.ready()

    def wait(self):
        """Block until the bridge is closed."""
        self._closed.wait()

    def close(self):
        """Stop both forwarders and mark the bridge closed.

        Both greenlets are killed unconditionally. When one direction ends on
        its own it sets the closed flag, but the other forwarder may still be
        running and would otherwise leak. Killing an already-finished
        greenlet does nothing. The flag is set explicitly because a greenlet
        killed before it ever ran never executes its ``finally`` clause, and
        wait() would then block forever.
        """
        # getattr guards __del__ on an instance whose __init__ failed early.
        for forwarder in (getattr(self, '_in2out', None), getattr(self, '_out2in', None)):
            if forwarder is not None:
                forwarder.kill()
        self._closed.set()
class Channel(object):
    """Buffered message channel that hands messages to a ``Publisher``.

    :meth:`send` queues a message. While at least one watcher is subscribed,
    the message is immediately taken off the queue and passed to the
    publisher. Otherwise it stays buffered until the first :meth:`watch`
    call flushes the backlog. The ``StopIteration`` class is used as the
    end-of-stream sentinel; receiving it marks the channel closed.

    NOTE(review): ``Publisher``, ``Queue`` and ``Event`` are imported outside
    this block (presumably project/gevent types) — confirm their semantics.
    """
    __slots__ = ('_recvq', '_closed', '_mon')

    def __init__(self):
        self._mon = Publisher()
        self._recvq = Queue()
        self._closed = Event()

    def __iter__(self):
        """Yield received messages until the end-of-stream sentinel."""
        while not self.closed:
            msg = self.recv()
            if msg is StopIteration:
                break
            yield msg

    def __len__(self):
        """ Number of messages in buffer which haven't been delivered to watchers """
        return self._recvq.qsize()

    def __del__(self):
        self.close()

    def send(self, msg):
        """Queue ``msg``; pass it on right away if anyone is watching."""
        self._recvq.put_nowait(msg)
        if len(self._mon):
            self.recv()

    def watch(self):
        """Subscribe a new watcher and return the subscriber object.

        The first watcher to subscribe also triggers delivery of everything
        buffered so far.
        """
        was_first = len(self._mon) == 0
        subscriber = self._mon.subscribe()
        if was_first:
            self._recvall()
        return subscriber

    def _recvall(self):
        """Deliver every buffered message, stopping at end of stream.

        Stops as soon as the channel is closed. Items queued behind the
        StopIteration sentinel (e.g. from a second close(), or a send after
        close, made before anyone watched) would otherwise make recv() raise
        RuntimeError("Closed").
        """
        while self._recvq.qsize() and not self.closed:
            self.recv()

    def write(self, data):
        """Send ``data`` wrapped as ``{'data': data}``."""
        self.send(dict(data=data))

    def recv(self):
        """Take one message off the queue, publish it and return it.

        :raises RuntimeError: if the channel is already closed.
        """
        if self.closed:
            # XXX: raise better exception
            raise RuntimeError("Closed")
        msg = self._recvq.get()
        self._mon.send(msg)
        if msg is StopIteration:
            self._closed.set()
        return msg

    def wait(self):
        """Block until the end-of-stream sentinel has been received."""
        return self._closed.wait()

    @property
    def closed(self):
        """True once the end-of-stream sentinel has been received."""
        return self._closed.ready()

    def close(self):
        """Queue the end-of-stream sentinel unless already closed."""
        if not self.closed:
            self.send(StopIteration)

    def datastream(self):
        """Wrap this channel in a ``DataStream``."""
        return DataStream(self)
class RaidenService(Runnable):
    """ A Raiden node. """

    def __init__(
            self,
            chain: BlockChainService,
            query_start_block: typing.BlockNumber,
            default_registry: TokenNetworkRegistry,
            default_secret_registry: SecretRegistry,
            private_key_bin,
            transport,
            raiden_event_handler,
            config,
            discovery=None,
    ):
        """Wire up the node's collaborators; no I/O besides creating the db dir.

        :raises ValueError: if ``private_key_bin`` is not a 32-byte ``bytes``.
        """
        super().__init__()
        if not isinstance(private_key_bin, bytes) or len(private_key_bin) != 32:
            raise ValueError('invalid private_key')

        self.tokennetworkids_to_connectionmanagers = dict()
        self.identifier_to_results: typing.Dict[typing.PaymentID, AsyncResult, ] = dict()

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config
        self.privkey = private_key_bin
        self.address = privatekey_to_address(private_key_bin)
        self.discovery = discovery

        self.private_key = PrivateKey(private_key_bin)
        self.pubkey = self.private_key.public_key.format(compressed=False)
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler

        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state changes events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir
            # Prevent concurrent access to the same db
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        # Serializes blockchain event polling with filter installation.
        self.event_poll_lock = gevent.lock.Semaphore()

    def start(self):
        """ Start the node synchronously. Raises directly if anything went wrong on startup """
        if not self.stop_event.ready():
            raise RuntimeError(f'{self!r} already started')
        self.stop_event.clear()

        if self.database_dir is not None:
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # start the registration early to speed up the start
        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet = gevent.spawn(
                self.discovery.register,
                self.address,
                self.config['transport']['udp']['external_ip'],
                self.config['transport']['udp']['external_port'],
            )

        storage = sqlite.SQLiteStorage(self.database_path, serialize.JSONSerializer())
        self.wal = wal.restore_to_state_change(
            transition_function=node.state_transition,
            storage=storage,
            state_change_identifier='latest',
        )

        if self.wal.state_manager.current_state is None:
            log.debug(
                'No recoverable state available, created inital state',
                node=pex(self.address),
            )
            block_number = self.chain.block_number()

            state_change = ActionInitChain(
                random.Random(),
                block_number,
                self.chain.node_address,
                self.chain.network_id,
            )
            self.wal.log_and_dispatch(state_change)
            payment_network = PaymentNetworkState(
                self.default_registry.address,
                [],  # empty list of token network states as it's the node's startup
            )
            state_change = ContractReceiveNewPaymentNetwork(
                constants.EMPTY_HASH,
                payment_network,
            )
            self.handle_state_change(state_change)

            # On first run Raiden needs to fetch all events for the payment
            # network, to reconstruct all token network graphs and find opened
            # channels
            last_log_block_number = 0
        else:
            # The `Block` state change is dispatched only after all the events
            # for that given block have been processed, filters can be safely
            # installed starting from this position without losing events.
            last_log_block_number = views.block_number(self.wal.state_manager.current_state)
            log.debug(
                'Restored state from WAL',
                last_restored_block=last_log_block_number,
                node=pex(self.address),
            )

            known_networks = views.get_payment_network_identifiers(
                views.state_from_raiden(self),
            )
            if known_networks and self.default_registry.address not in known_networks:
                configured_registry = pex(self.default_registry.address)
                known_registries = lpex(known_networks)
                raise RuntimeError(
                    f'Token network address mismatch.\n'
                    f'Raiden is configured to use the smart contract '
                    f'{configured_registry}, which conflicts with the current known '
                    f'smart contracts {known_registries}',
                )

        # Clear ref cache & disable caching
        serialize.RaidenJSONDecoder.ref_cache.clear()
        serialize.RaidenJSONDecoder.cache_object_references = False

        # Restore the current snapshot group
        state_change_qty = self.wal.storage.count_state_changes()
        self.snapshot_group = state_change_qty // SNAPSHOT_STATE_CHANGES_COUNT

        # Install the filters using the correct from_block value, otherwise
        # blockchain logs can be lost.
        self.install_all_blockchain_filters(
            self.default_registry,
            self.default_secret_registry,
            last_log_block_number,
        )

        # Complete the first_run of the alarm task and synchronize with the
        # blockchain since the last run.
        #
        # Notes about setup order:
        # - The filters must be polled after the node state has been primed,
        #   otherwise the state changes won't have effect.
        # - The alarm must complete its first run before the transport is started,
        #   to avoid rejecting messages for unknown channels.
        #   NOTE(review): the code below starts the transport *before*
        #   alarm.first_run(); see the next comment for why.
        self.alarm.register_callback(self._callback_new_block)

        # alarm.first_run may process some new channel, which would start_health_check_for
        # a partner, that's why transport needs to be already started at this point
        self.transport.start(self)

        self.alarm.first_run()

        chain_state = views.state_from_raiden(self)
        # Dispatch pending transactions
        pending_transactions = views.get_pending_transactions(chain_state)
        log.debug(
            'Processing pending transactions',
            num_pending_transactions=len(pending_transactions),
            node=pex(self.address),
        )
        # While this lock is held, handle_state_change() logs state changes
        # but does not dispatch their events.
        with self.dispatch_events_lock:
            for transaction in pending_transactions:
                try:
                    self.raiden_event_handler.on_raiden_event(self, transaction)
                except RaidenRecoverableError as e:
                    log.error(str(e))
                except RaidenUnrecoverableError as e:
                    # On the main network only InvalidDBData is fatal here;
                    # elsewhere every unrecoverable error propagates.
                    if self.config['network_type'] == NetworkType.MAIN:
                        if isinstance(e, InvalidDBData):
                            raise
                        log.error(str(e))
                    else:
                        raise

        self.alarm.start()

        # after transport and alarm is started, send queued messages
        events_queues = views.get_all_messagequeues(chain_state)

        for queue_identifier, event_queue in events_queues.items():
            self.start_health_check_for(queue_identifier.recipient)

            for event in event_queue:
                # repopulate identifier_to_results for pending transfers
                if type(event) == SendDirectTransfer:
                    self.identifier_to_results[event.payment_identifier] = AsyncResult()

                message = message_from_sendevent(event, self.address)
                self.sign(message)
                self.transport.send_async(queue_identifier, message)

        # exceptions on these subtasks should crash the app and bubble up
        self.alarm.link_exception(self.on_error)
        self.transport.link_exception(self.on_error)

        # Health check needs the transport layer
        self.start_neighbours_healthcheck()

        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet.get()  # re-raise if exception occurred

        super().start()

    def _run(self):
        """ Busy-wait on long-lived subtasks/greenlets, re-raise if any error occurs """
        try:
            self.stop_event.wait()
        except gevent.GreenletExit:  # killed without exception
            self.stop_event.set()
            gevent.killall([self.alarm, self.transport])  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()
            raise

    def stop(self):
        """ Stop the node gracefully. Raise if any stop-time error occurred on any subtask """
        if self.stop_event.ready():  # not started
            return

        # Needs to come before any greenlets joining
        self.stop_event.set()

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        # NOTE(review): no timeout is armed in this method; gevent.Timeout is
        # only caught here if one of the callees raises it.
        try:
            self.transport.stop()
            self.alarm.stop()
            self.transport.get()
            self.alarm.get()
            self.blockchain_events.uninstall_all_event_listeners()
        except gevent.Timeout:
            pass

        self.blockchain_events.reset()

        if self.db_lock is not None:
            self.db_lock.release()

    def add_pending_greenlet(self, greenlet: gevent.Greenlet):
        """Route exceptions raised by ``greenlet`` to ``self.on_error``."""
        greenlet.link_exception(self.on_error)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def start_neighbours_healthcheck(self):
        """Start health checks for every neighbour except the bootstrap address."""
        for neighbour in views.all_neighbour_nodes(self.wal.state_manager.current_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def get_block_number(self):
        """Return the block number recorded in the current chain state."""
        return views.block_number(self.wal.state_manager.current_state)

    def handle_state_change(self, state_change):
        """Log ``state_change`` to the WAL, dispatch its events, maybe snapshot.

        :returns: the events produced by the state transition, or ``[]`` when
            event dispatching is currently held back by
            ``dispatch_events_lock`` (the state change is still logged).
        """
        log.debug('STATE CHANGE', node=pex(self.address), state_change=state_change)

        event_list = self.wal.log_and_dispatch(state_change)

        if self.dispatch_events_lock.locked():
            return []

        for event in event_list:
            log.debug('RAIDEN EVENT', node=pex(self.address), raiden_event=event)

            try:
                self.raiden_event_handler.on_raiden_event(
                    raiden=self,
                    event=event,
                )
            except RaidenRecoverableError as e:
                log.error(str(e))
            except RaidenUnrecoverableError as e:
                # On the main network only InvalidDBData is fatal here;
                # elsewhere every unrecoverable error propagates.
                if self.config['network_type'] == NetworkType.MAIN:
                    if isinstance(e, InvalidDBData):
                        raise
                    log.error(str(e))
                else:
                    raise

        # Take a snapshot every SNAPSHOT_STATE_CHANGES_COUNT
        # TODO: Gather more data about storage requirements
        # and update the value to specify how often we need
        # capturing a snapshot should take place
        new_snapshot_group = (
            self.wal.storage.count_state_changes() // SNAPSHOT_STATE_CHANGES_COUNT
        )
        if new_snapshot_group > self.snapshot_group:
            log.debug(f'Storing snapshot: {new_snapshot_group}')
            self.wal.snapshot()
            self.snapshot_group = new_snapshot_group

        return event_list

    def set_node_network_state(self, node_address, network_state):
        """Record a network-state change for ``node_address`` in the WAL."""
        state_change = ActionChangeNodeNetworkState(node_address, network_state)
        self.wal.log_and_dispatch(state_change)

    def start_health_check_for(self, node_address):
        """Ask the transport to start health-checking ``node_address``."""
        self.transport.start_health_check(node_address)

    def _callback_new_block(self, latest_block):
        """Called once a new block is detected by the alarm task.

        Note:
            This should be called only once per block, otherwise there will be
            duplicated `Block` state changes in the log.

            Therefore this method should be called only once a new block is
            mined with the appropriate block_number argument from the
            AlarmTask.
        """
        # Raiden relies on blockchain events to update its off-chain state,
        # therefore some APIs /used/ to forcefully poll for events.
        #
        # This was done for APIs which have on-chain side-effects, e.g.
        # opening a channel, where polling the event is required to update
        # off-chain state to providing a consistent view to the caller, e.g.
        # the channel exists after the API call returns.
        #
        # That pattern introduced a race, because the events are returned only
        # once per filter, and this method would be called concurrently by the
        # API and the AlarmTask. The following lock is necessary, to ensure the
        # expected side-effects are properly applied (introduced by the commit
        # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d)
        latest_block_number = latest_block['number']
        with self.event_poll_lock:
            for event in self.blockchain_events.poll_blockchain_events(latest_block_number):
                # These state changes will be processed with a block_number
                # which is /larger/ than the ChainState's block_number.
                on_blockchain_event(self, event)

            # On restart the Raiden node will re-create the filters with the
            # ethereum node. These filters will have the from_block set to the
            # value of the latest Block state change. To avoid missing events
            # the Block state change is dispatched only after all of the events
            # have been processed.
            #
            # This means on some corner cases a few events may be applied
            # twice, this will happen if the node crashed and some events have
            # been processed but the Block state change has not been
            # dispatched.
            state_change = Block(
                block_number=latest_block_number,
                gas_limit=latest_block['gasLimit'],
                block_hash=bytes(latest_block['hash']),
            )
            self.handle_state_change(state_change)

    def sign(self, message):
        """ Sign message inplace.

        :raises ValueError: if ``message`` is not a ``SignedMessage``.
        """
        if not isinstance(message, SignedMessage):
            raise ValueError('{} is not signable.'.format(repr(message)))

        message.sign(self.private_key)

    def install_all_blockchain_filters(
            self,
            token_network_registry_proxy: TokenNetworkRegistry,
            secret_registry_proxy: SecretRegistry,
            from_block: typing.BlockNumber,
    ):
        """Install registry, secret-registry and per-token-network listeners.

        Runs under ``event_poll_lock`` so installation does not interleave
        with ``_callback_new_block`` polling.
        """
        with self.event_poll_lock:
            node_state = views.state_from_raiden(self)
            token_networks = views.get_token_network_identifiers(
                node_state,
                token_network_registry_proxy.address,
            )

            self.blockchain_events.add_token_network_registry_listener(
                token_network_registry_proxy,
                from_block,
            )
            self.blockchain_events.add_secret_registry_listener(
                secret_registry_proxy,
                from_block,
            )

            for token_network in token_networks:
                token_network_proxy = self.chain.token_network(token_network)
                self.blockchain_events.add_token_network_listener(
                    token_network_proxy,
                    from_block,
                )

    def connection_manager_for_token_network(self, token_network_identifier):
        """Return (creating and caching if needed) the token network's ConnectionManager.

        :raises InvalidAddress: if the identifier is not a binary address or is
            not registered in the default registry.
        """
        if not is_binary_address(token_network_identifier):
            raise InvalidAddress('token address is not valid.')

        known_token_networks = views.get_token_network_identifiers(
            views.state_from_raiden(self),
            self.default_registry.address,
        )

        if token_network_identifier not in known_token_networks:
            raise InvalidAddress('token is not registered.')

        manager = self.tokennetworkids_to_connectionmanagers.get(token_network_identifier)

        if manager is None:
            manager = ConnectionManager(self, token_network_identifier)
            self.tokennetworkids_to_connectionmanagers[token_network_identifier] = manager

        return manager

    def leave_all_token_networks(self):
        """Record an ActionLeaveAllNetworks state change in the WAL."""
        state_change = ActionLeaveAllNetworks()
        self.wal.log_and_dispatch(state_change)

    def close_and_settle(self):
        """Leave all token networks and, if any were joined, wait for settlement."""
        log.info('raiden will close and settle all channels now')

        self.leave_all_token_networks()

        if self.tokennetworkids_to_connectionmanagers:
            waiting.wait_for_settle_all_channels(
                self,
                self.alarm.sleep_time,
            )

    def mediated_transfer_async(
            self,
            token_network_identifier: typing.TokenNetworkID,
            amount: typing.TokenAmount,
            target: typing.Address,
            identifier: typing.PaymentID,
    ):
        """ Transfer `amount` between this node and `target`.

        This method will start an asynchronous transfer, the transfer might fail
        or succeed depending on a couple of factors:

            - Existence of a path that can be used, through the usage of direct
              or intermediary channels.
            - Network speed, making the transfer sufficiently fast so it doesn't
              expire.

        :returns: the ``AsyncResult`` tracking the payment.
        """
        async_result = self.start_mediated_transfer(
            token_network_identifier,
            amount,
            target,
            identifier,
        )

        return async_result

    def direct_transfer_async(self, token_network_identifier, amount, target, identifier):
        """ Do a direct transfer with target.

        Direct transfers are non cancellable and non expirable, since these
        transfers are a signed balance proof with the transferred amount
        incremented.

        Because the transfer is non cancellable, there is a level of trust with
        the target. After the message is sent the target is effectively paid
        and then it is not possible to revert.

        The async result will be set to False iff there is no direct channel
        with the target or the payer does not have balance to complete the
        transfer, otherwise because the transfer is non expirable the async
        result *will never be set to False* and if the message is sent it will
        hang until the target node acknowledge the message.

        This transfer should be used as an optimization, since only two packets
        are required to complete the transfer (from the payers perspective),
        whereas the mediated transfer requires 6 messages.

        :returns: the ``AsyncResult`` registered for ``identifier`` (a default
            identifier is generated when ``identifier`` is ``None``).
        """
        self.start_health_check_for(target)

        if identifier is None:
            identifier = create_default_identifier()

        direct_transfer = ActionTransferDirect(
            token_network_identifier,
            target,
            identifier,
            amount,
        )

        async_result = AsyncResult()
        self.identifier_to_results[identifier] = async_result

        self.handle_state_change(direct_transfer)

        # The result was previously created and registered but never handed
        # back, leaving callers nothing to wait on.
        return async_result

    def start_mediated_transfer(
            self,
            token_network_identifier: typing.TokenNetworkID,
            amount: typing.TokenAmount,
            target: typing.Address,
            identifier: typing.PaymentID,
    ):
        """Start a mediated transfer and return its ``AsyncResult``.

        If a transfer with ``identifier`` is already tracked, its existing
        result is returned and nothing new is dispatched.
        """
        self.start_health_check_for(target)

        if identifier is None:
            identifier = create_default_identifier()

        if identifier in self.identifier_to_results:
            return self.identifier_to_results[identifier]

        async_result = AsyncResult()
        self.identifier_to_results[identifier] = async_result

        secret = random_secret()
        init_initiator_statechange = initiator_init(
            self,
            identifier,
            amount,
            secret,
            token_network_identifier,
            target,
        )

        # Dispatch the state change even if there are no routes to create the
        # wal entry.
        self.handle_state_change(init_initiator_statechange)

        return async_result

    def mediate_mediated_transfer(self, transfer: LockedTransfer):
        """Handle ``transfer`` as a mediator."""
        init_mediator_statechange = mediator_init(self, transfer)
        self.handle_state_change(init_mediator_statechange)

    def target_mediated_transfer(self, transfer: LockedTransfer):
        """Handle ``transfer`` as its target, health-checking the initiator."""
        self.start_health_check_for(transfer.initiator)
        init_target_statechange = target_init(transfer)
        self.handle_state_change(init_target_statechange)
class UserAddressManager:
    """ Matrix user <-> eth address mapping and user / address reachability helper.

    In Raiden the smallest unit of addressability is a node with an associated Ethereum address.
    In Matrix it's a user. Matrix users are (at the moment) bound to a specific homeserver.
    Since we want to provide resiliency against unavailable homeservers a single Raiden node with
    a single Ethereum address can be in control over multiple Matrix users on multiple homeservers.

    Therefore we need to perform a many-to-one mapping of Matrix users to Ethereum addresses.
    Each Matrix user has a presence state (ONLINE, OFFLINE).
    One of the preconditions of running a Raiden node is that there can always only be one node
    online for a particular address at a time.
    That means we can synthesize the reachability of an address from the user presence states.

    This helper internally tracks both the user presence and address reachability for addresses
    that have been marked as being 'interesting' (by calling the `.add_address()` method).
    Additionally it provides the option of passing callbacks that will be notified when
    presence / reachability change.
    """

    def __init__(
        self,
        client: GMatrixClient,
        displayname_cache: DisplayNameCache,
        address_reachability_changed_callback: Callable[[Address, AddressReachability], None],
        user_presence_changed_callback: Optional[Callable[[User, UserPresence], None]] = None,
        _log_context: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._client = client
        self._displayname_cache = displayname_cache
        self._address_reachability_changed_callback = address_reachability_changed_callback
        self._user_presence_changed_callback = user_presence_changed_callback
        self._stop_event = Event()

        self._reset_state()

        self._log_context = _log_context
        # Cached bound logger; populated by the ``log`` property once the
        # client's user_id is known.
        self._log = None
        self._listener_id: Optional[UUID] = None

    def start(self) -> None:
        """ Start listening for presence updates.

        Should be called before ``.login()`` is called on the underlying client. """
        assert self._listener_id is None, "UserAddressManager.start() called twice"
        self._stop_event.clear()
        self._listener_id = self._client.add_presence_listener(self._presence_listener)

    def stop(self) -> None:
        """ Stop listening on presence updates. """
        assert self._listener_id is not None, "UserAddressManager.stop() called before start"
        self._stop_event.set()
        self._client.remove_presence_listener(self._listener_id)
        self._listener_id = None
        self._log = None
        self._reset_state()

    @property
    def known_addresses(self) -> Set[Address]:
        """ Return all addresses we keep track of """
        # This must return a copy of the current keys, because the container
        # may be modified while these values are used. Issue: #5240
        return set(self._address_to_userids)

    def is_address_known(self, address: Address) -> bool:
        """ Is the given ``address`` reachability being monitored? """
        return address in self._address_to_userids

    def add_address(self, address: Address) -> None:
        """ Add ``address`` to the known addresses that are being observed for reachability. """
        # Since _address_to_userids is a defaultdict accessing the key creates the entry
        _ = self._address_to_userids[address]

    def add_userid_for_address(self, address: Address, user_id: str) -> None:
        """ Add a ``user_id`` for the given ``address``.

        Implicitly adds the address if it was unknown before.
        """
        self._address_to_userids[address].add(user_id)

    def add_userids_for_address(self, address: Address, user_ids: Iterable[str]) -> None:
        """ Add multiple ``user_ids`` for the given ``address``.

        Implicitly adds any addresses if they were unknown before.
        """
        self._address_to_userids[address].update(user_ids)

    def get_userids_for_address(self, address: Address) -> Set[str]:
        """ Return all known user ids for the given ``address``. """
        if not self.is_address_known(address):
            return set()
        return self._address_to_userids[address]

    def get_userid_presence(self, user_id: str) -> UserPresence:
        """ Return the current presence state of ``user_id``. """
        return self._userid_to_presence.get(user_id, UserPresence.UNKNOWN)

    def get_address_reachability(self, address: Address) -> AddressReachability:
        """ Return the current reachability state for ``address``. """
        return self._address_to_reachabilitystate.get(
            address, UNKNOWN_REACHABILITY_STATE
        ).reachability

    def get_address_reachability_state(self, address: Address) -> ReachabilityState:
        """ Return the current reachability state for ``address``. """
        return self._address_to_reachabilitystate.get(address, UNKNOWN_REACHABILITY_STATE)

    def force_user_presence(self, user: User, presence: UserPresence) -> None:
        """ Forcibly set the ``user`` presence to ``presence``.

        This method is only provided to cover an edge case in our use of the Matrix protocol and
        should **not** generally be used.
        """
        self._userid_to_presence[user.user_id] = presence

    def populate_userids_for_address(self, address: Address, force: bool = False) -> None:
        """ Populate known user ids for the given ``address`` from the server directory.

        If ``force`` is ``True`` perform the directory search even if there
        already are known users.
        """
        if force or not self.get_userids_for_address(address):
            self.add_userids_for_address(
                address,
                (
                    user.user_id
                    for user in self._client.search_user_directory(
                        to_normalized_address(address)
                    )
                    if self._validate_userid_signature(user)
                ),
            )

    def track_address_presence(
        self, address: Address, user_ids: Union[Set[str], FrozenSet[str]] = frozenset()
    ) -> None:
        """
        Update synthesized address presence state.

        Triggers callback (if any) in case the state has changed.
        """
        # Is this address already tracked for all given user_ids?
        state_known = (
            self.get_address_reachability_state(address).reachability
            != AddressReachability.UNKNOWN
        )
        no_new_user_ids = user_ids.issubset(self._address_to_userids[address])
        if state_known and no_new_user_ids:
            return

        # Update presence
        self.add_userids_for_address(address, user_ids)
        userids_to_presence = {}
        for uid in user_ids:
            presence = self._fetch_user_presence(uid)
            userids_to_presence[uid] = presence
            # We assume that this is only used when no presence has been set,
            # yet. So let's use a presence_update_id that's smaller than the
            # usual ones, which start at 0.
            self._set_user_presence(uid, presence, presence_update_id=-1)

        # Use the instance logger like the rest of the class so the bound
        # context (current user / node) is included.
        self.log.debug(
            "Fetched user presences",
            address=to_checksum_address(address),
            userids_to_presence=userids_to_presence,
        )

        self._maybe_address_reachability_changed(address)

    def get_reachability_from_matrix(self, user_ids: Iterable[str]) -> AddressReachability:
        """ Get the current reachability without any side effects

        Since this does not even do any caching, don't use it for the normal
        communication between participants in a channel.
        """
        for uid in user_ids:
            presence = self._fetch_user_presence(uid)
            if (
                USER_PRESENCE_TO_ADDRESS_REACHABILITY[presence]
                == AddressReachability.REACHABLE
            ):
                return AddressReachability.REACHABLE

        return AddressReachability.UNREACHABLE

    def _maybe_address_reachability_changed(self, address: Address) -> None:
        # A Raiden node may have multiple Matrix users, this happens when
        # Raiden roams from a Matrix server to another. This loop goes over all
        # these users and uses the "best" presence. IOW, if there is a single
        # Matrix user that is reachable, then the Raiden node is considered
        # reachable.
        userids = self._address_to_userids[address].copy()
        composite_presence = {self._userid_to_presence.get(uid) for uid in userids}

        # "Best" is the first UserPresence member, in declaration order, that
        # any of the users currently has.
        new_presence = UserPresence.UNKNOWN
        for presence in UserPresence.__members__.values():
            if presence in composite_presence:
                new_presence = presence
                break

        new_address_reachability = USER_PRESENCE_TO_ADDRESS_REACHABILITY[new_presence]

        prev_reachability_state = self.get_address_reachability_state(address)
        if new_address_reachability == prev_reachability_state.reachability:
            return
        now = datetime.now()

        self.log.debug(
            "Changing address reachability state",
            address=to_checksum_address(address),
            prev_state=prev_reachability_state.reachability,
            state=new_address_reachability,
            last_change=prev_reachability_state.time,
            change_after=now - prev_reachability_state.time,
        )

        self._address_to_reachabilitystate[address] = ReachabilityState(
            new_address_reachability, now
        )
        self._address_reachability_changed_callback(address, new_address_reachability)

    def _presence_listener(self, event: Dict[str, Any], presence_update_id: int) -> None:
        """
        Update cached user presence state from Matrix presence events.

        Due to the possibility of nodes using accounts on multiple homeservers a composite
        address state is synthesised from the cached individual user presence states.
        """
        if self._stop_event.ready():
            return

        user_id = event["sender"]

        if event["type"] != "m.presence" or user_id == self._user_id:
            return

        address = address_from_userid(user_id)

        # Not a user we've whitelisted, skip. This needs to be on the top of
        # the function so that we don't request the displayname of users that
        # are not important for the node. The presence is updated for every
        # user on the first sync, since every Raiden node is a member of a
        # broadcast room. This can result in thousands requests to the Matrix
        # server in the first sync which will lead to slow startup times and
        # presence problems.
        if address is None or not self.is_address_known(address):
            return

        user = self._user_from_id(user_id, event["content"].get("displayname"))

        if not user:
            return

        self._displayname_cache.warm_users([user])

        address = self._validate_userid_signature(user)
        if not address:
            return

        self.add_userid_for_address(address, user_id)

        new_state = UserPresence(event["content"]["presence"])

        self._set_user_presence(user_id, new_state, presence_update_id)
        self._maybe_address_reachability_changed(address)

    def _reset_state(self) -> None:
        self._address_to_userids: Dict[Address, Set[str]] = defaultdict(set)
        self._address_to_reachabilitystate: Dict[Address, ReachabilityState] = dict()
        self._userid_to_presence: Dict[str, UserPresence] = dict()
        self._userid_to_presence_update_id: Dict[str, int] = dict()

    @property
    def _user_id(self) -> str:
        user_id = getattr(self._client, "user_id", None)
        assert user_id, f"{self.__class__.__name__}._user_id accessed before client login"
        return user_id

    def _user_from_id(self, user_id: str, display_name: Optional[str] = None) -> Optional[User]:
        try:
            return User(self._client.api, user_id, display_name)
        except ValueError:
            log.error("Matrix server returned an invalid user_id.")
        return None

    def _fetch_user_presence(self, user_id: str) -> UserPresence:
        try:
            presence = UserPresence(self._client.get_user_presence(user_id))
        except MatrixRequestError:
            # The following exception will be raised if the local user and the
            # target user do not have a shared room:
            #
            #   MatrixRequestError: 403:
            #   {"errcode":"M_FORBIDDEN","error":"You are not allowed to see their presence."}
            presence = UserPresence.UNKNOWN
            log.exception("Could not fetch user presence")

        return presence

    def _set_user_presence(
        self, user_id: str, presence: UserPresence, presence_update_id: int
    ) -> None:
        user = self._user_from_id(user_id)
        if not user:
            return

        # -1 is used in track_address_presence, so we use -2 as a default.
        if self._userid_to_presence_update_id.get(user_id, -2) >= presence_update_id:
            # We've already received a more recent presence (or the same one)
            return

        old_presence = self._userid_to_presence.get(user_id)
        if old_presence == presence:
            # This can happen when force_user_presence is used. For most other
            # cases the presence_update_id check will return first.
            return

        self._userid_to_presence[user_id] = presence
        self._userid_to_presence_update_id[user_id] = presence_update_id
        self.log.debug(
            "Changing user presence state",
            user_id=user_id,
            prev_state=old_presence,
            state=presence,
        )
        if self._user_presence_changed_callback:
            self._displayname_cache.warm_users([user])
            self._user_presence_changed_callback(user, presence)

    @staticmethod
    def _validate_userid_signature(user: User) -> Optional[Address]:
        return validate_userid_signature(user)

    @property
    def log(self) -> BoundLoggerLazyProxy:
        """Logger bound to ``_log_context`` plus, once logged in, the user/node.

        The bound logger is cached only after the client has a user_id.
        """
        if self._log:
            return self._log  # type: ignore

        # Copy so the caller-supplied ``_log_context`` dict is never mutated.
        context = dict(self._log_context or {})

        # Only cache the logger once the user_id becomes available. A plain
        # hasattr() check is true even when the attribute is None, in which
        # case ``_user_id`` would fail its assertion.
        user_id = getattr(self._client, "user_id", None)
        if user_id:
            context["current_user"] = user_id
            context["node"] = node_address_from_userid(user_id)

            bound_log = log.bind(**context)
            self._log = bound_log
            return bound_log

        # Apply the `_log_context` even if the user_id is not yet available
        return log.bind(**context)
class _Alarm(object): """A :class:`gevent.event.AsyncResult`-like class to wait until the final set time. """ __slots__ = ('time', 'value', 'timer', 'event') #: Implement it to decide to reschedule the awaking time. It's a function #: which takes 2 time arguments. The first argument is new time to set, #: and the second argument is the previously accepted time. Both arguments #: are never ``None``. accept = NotImplemented def __init__(self): self.time = self.value = self.timer = None self.event = Event() def set(self, time, value=None): """Sets the time to awake up. If the time is not accepted, will be ignored and it returns ``False``. Otherwise, returns ``True``. """ if time is None: raise TypeError('use clear() instead of setting none time') elif self.time is not None and not self.accept(time, self.time): # Not accepted. return False self._reset(time, value) delay = time - time_.time() if delay > 0: # Set timer to wake up. self.timer = get_hub().loop.timer(delay) self.timer.start(self.event.set) else: # Wake up immediately. self.event.set() return True def when(self): """When it will be awoken or ``None``.""" return self.time def ready(self): """Whether it has been awoken.""" return self.event.ready() def wait(self, timeout=None): """Waits until the awaking time. It returns the time.""" if self.event.wait(timeout): return self.time def get(self, block=True, timeout=None): """Waits until and gets the awaking time and the value.""" if not block and not self.ready(): raise Timeout if self.event.wait(timeout): return self.time, self.value raise Timeout(timeout) def clear(self): """Discards the schedule for awaking.""" self.event.clear() self._reset(None, None) def _reset(self, time, value): self.time = time self.value = value if self.timer is not None: self.timer.stop() def __nonzero__(self): return self.time is not None
class RaidenService(Runnable):
    """ A Raiden node.

    Owns the write-ahead log (``self.wal``) that stores the node state, the
    transport used to talk to other nodes, and the alarm task whose
    new-block callback polls the blockchain for contract events.
    """

    def __init__(
            self,
            chain: BlockChainService,
            query_start_block: BlockNumber,
            default_registry: TokenNetworkRegistry,
            default_secret_registry: SecretRegistry,
            transport,
            raiden_event_handler,
            message_handler,
            config,
            discovery=None,
    ):
        """ Wire up the node's collaborators without starting any task.

        The node starts in the stopped state; call :meth:`start` to run it.
        Unless ``config['database_path']`` is ``':memory:'``, the database
        directory is created here and a file lock inside it is prepared (but
        not yet acquired).
        """
        super().__init__()
        self.tokennetworkids_to_connectionmanagers: ConnectionManagerDict = dict()
        self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict(dict)

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config

        self.signer: Signer = LocalSigner(self.chain.client.privkey)
        self.address = self.signer.address
        self.discovery = discovery
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler
        self.message_handler = message_handler

        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        self.wal: Optional[wal.WriteAheadLog] = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state changes events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.contract_manager = ContractManager(config['contracts_path'])
        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir

            # Two raiden processes must not write to the same database, even
            # though the database itself may be consistent. If more than one
            # nodes writes state changes to the same WAL there are no
            # guarantees about recovery, this happens because during recovery
            # the WAL replay can not be deterministic.
            lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.serialization_file = None
            self.db_lock = None

        self.event_poll_lock = gevent.lock.Semaphore()
        self.gas_reserve_lock = gevent.lock.Semaphore()
        self.payment_identifier_lock = gevent.lock.Semaphore()

    def start(self):
        """ Start the node synchronously. Raises directly if anything went wrong on startup """
        if not self.stop_event.ready():
            raise RuntimeError(f'{self!r} already started')
        self.stop_event.clear()

        if self.database_dir is not None:
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # start the registration early to speed up the start
        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet = gevent.spawn(
                self.discovery.register,
                self.address,
                self.config['transport']['udp']['external_ip'],
                self.config['transport']['udp']['external_port'],
            )

        self.maybe_upgrade_db()

        storage = sqlite.SerializedSQLiteStorage(
            database_path=self.database_path,
            serializer=serialize.JSONSerializer(),
        )
        storage.log_run()
        self.wal = wal.restore_to_state_change(
            transition_function=node.state_transition,
            storage=storage,
            state_change_identifier='latest',
        )

        if self.wal.state_manager.current_state is None:
            log.debug(
                'No recoverable state available, created inital state',
                node=pex(self.address),
            )
            # On first run Raiden needs to fetch all events for the payment
            # network, to reconstruct all token network graphs and find opened
            # channels
            last_log_block_number = self.query_start_block

            state_change = ActionInitChain(
                random.Random(),
                last_log_block_number,
                self.chain.node_address,
                self.chain.network_id,
            )
            self.handle_state_change(state_change)

            payment_network = PaymentNetworkState(
                self.default_registry.address,
                [],  # empty list of token network states as it's the node's startup
            )
            state_change = ContractReceiveNewPaymentNetwork(
                constants.EMPTY_HASH,
                payment_network,
                last_log_block_number,
            )
            self.handle_state_change(state_change)
        else:
            # The `Block` state change is dispatched only after all the events
            # for that given block have been processed, filters can be safely
            # installed starting from this position without losing events.
            last_log_block_number = views.block_number(self.wal.state_manager.current_state)
            log.debug(
                'Restored state from WAL',
                last_restored_block=last_log_block_number,
                node=pex(self.address),
            )

            known_networks = views.get_payment_network_identifiers(views.state_from_raiden(self))
            if known_networks and self.default_registry.address not in known_networks:
                configured_registry = pex(self.default_registry.address)
                known_registries = lpex(known_networks)
                raise RuntimeError(
                    f'Token network address mismatch.\n'
                    f'Raiden is configured to use the smart contract '
                    f'{configured_registry}, which conflicts with the current known '
                    f'smart contracts {known_registries}',
                )

        # Restore the current snapshot group
        state_change_qty = self.wal.storage.count_state_changes()
        self.snapshot_group = state_change_qty // SNAPSHOT_STATE_CHANGES_COUNT

        # Install the filters using the correct from_block value, otherwise
        # blockchain logs can be lost.
        self.install_all_blockchain_filters(
            self.default_registry,
            self.default_secret_registry,
            last_log_block_number,
        )

        # Complete the first_run of the alarm task and synchronize with the
        # blockchain since the last run.
        #
        # Notes about setup order:
        # - The filters must be polled after the node state has been primed,
        # otherwise the state changes won't have effect.
        # - The alarm must complete its first run before the transport is started,
        # to reject messages for closed/settled channels.
        self.alarm.register_callback(self._callback_new_block)
        with self.dispatch_events_lock:
            self.alarm.first_run(last_log_block_number)

        chain_state = views.state_from_raiden(self)
        self._initialize_transactions_queues(chain_state)
        self._initialize_whitelists(chain_state)
        self._initialize_payment_statuses(chain_state)
        # send messages in queue before starting transport,
        # this is necessary to avoid a race where, if the transport is started
        # before the messages are queued, actions triggered by it can cause new
        # messages to be enqueued before these older ones
        self._initialize_messages_queues(chain_state)

        # The transport must not ever be started before the alarm task's
        # `first_run()` has been, because it's this method which synchronizes the
        # node with the blockchain, including the channel's state (if the channel
        # is closed on-chain new messages must be rejected, which will not be the
        # case if the node is not synchronized)
        self.transport.start(
            raiden_service=self,
            message_handler=self.message_handler,
            prev_auth_data=chain_state.last_transport_authdata,
        )

        # First run has been called above!
        self.alarm.start()

        # exceptions on these subtasks should crash the app and bubble up
        self.alarm.link_exception(self.on_error)
        self.transport.link_exception(self.on_error)

        # Health check needs the transport layer
        self.start_neighbours_healthcheck(chain_state)

        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet.get()  # re-raise if exception occurred

        log.debug('Raiden Service started', node=pex(self.address))
        super().start()

    def _run(self, *args, **kwargs):  # pylint: disable=method-hidden
        """ Busy-wait on long-lived subtasks/greenlets, re-raise if any error occurs """
        try:
            self.stop_event.wait()
        except gevent.GreenletExit:  # killed without exception
            self.stop_event.set()
            gevent.killall([self.alarm, self.transport])  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()
            raise

    def stop(self):
        """ Stop the node gracefully. Raise if any stop-time error occurred on any subtask """
        if self.stop_event.ready():  # not started
            return

        # Needs to come before any greenlets joining
        self.stop_event.set()

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # NOTE(review): a timeout was meant to prevent an endless loop while
        # trying to contact a disconnected client, but the joins below do not
        # pass one — confirm whether one is needed.
        self.transport.stop()
        self.alarm.stop()

        self.transport.join()
        self.alarm.join()

        self.blockchain_events.uninstall_all_event_listeners()

        # Close storage DB to release internal DB lock
        self.wal.storage.conn.close()

        if self.db_lock is not None:
            self.db_lock.release()

        log.debug('Raiden Service stopped', node=pex(self.address))

    def add_pending_greenlet(self, greenlet: gevent.Greenlet):
        """ Crash the node (via ``on_error``) if `greenlet` dies with an exception. """
        greenlet.link_exception(self.on_error)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def start_neighbours_healthcheck(self, chain_state: ChainState):
        """ Start health checks for every neighbour except the bootstrap address. """
        for neighbour in views.all_neighbour_nodes(chain_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def get_block_number(self) -> BlockNumber:
        """ Block number of the node's current state in the WAL. """
        assert self.wal
        return views.block_number(self.wal.state_manager.current_state)

    def on_message(self, message: Message):
        self.message_handler.on_message(self, message)

    def handle_state_change(self, state_change: StateChange):
        """ Log `state_change` in the WAL, dispatch the resulting events and
        take a snapshot when a new snapshot group is reached.

        Returns the list of produced events, or ``[]`` while
        ``dispatch_events_lock`` is held (events are then not dispatched and
        no snapshot is taken).
        """
        assert self.wal

        log.debug(
            'State change',
            node=pex(self.address),
            state_change=_redact_secret(serialize.JSONSerializer.serialize(state_change)),
        )

        old_state = views.state_from_raiden(self)

        event_list = self.wal.log_and_dispatch(state_change)

        current_state = views.state_from_raiden(self)
        for balance_proof in views.detect_balance_proof_change(old_state, current_state):
            event_list.append(EventNewBalanceProofReceived(balance_proof))

        if self.dispatch_events_lock.locked():
            return []

        for event in event_list:
            log.debug(
                'Raiden event',
                node=pex(self.address),
                raiden_event=_redact_secret(serialize.JSONSerializer.serialize(event)),
            )
            self._handle_raiden_event(event)

        # Take a snapshot every SNAPSHOT_STATE_CHANGES_COUNT
        # TODO: Gather more data about storage requirements
        # and update the value to specify how often we need
        # capturing a snapshot should take place
        new_snapshot_group = self.wal.storage.count_state_changes() // SNAPSHOT_STATE_CHANGES_COUNT
        if new_snapshot_group > self.snapshot_group:
            log.debug('Storing snapshot', snapshot_id=new_snapshot_group)
            self.wal.snapshot()
            self.snapshot_group = new_snapshot_group

        return event_list

    def _handle_raiden_event(self, event):
        """ Run the event handler for `event` applying the node's error policy.

        - ``RaidenRecoverableError`` is logged and swallowed.
        - ``InvalidDBData`` is always re-raised.
        - ``RaidenUnrecoverableError`` is only logged when running in
          production with ``unrecoverable_error_should_crash`` disabled;
          otherwise it is re-raised.
        """
        try:
            self.raiden_event_handler.on_raiden_event(
                raiden=self,
                event=event,
            )
        except RaidenRecoverableError as e:
            log.error(str(e))
        except InvalidDBData:
            raise
        except RaidenUnrecoverableError as e:
            log_unrecoverable = (
                self.config['environment_type'] == Environment.PRODUCTION and
                not self.config['unrecoverable_error_should_crash']
            )
            if log_unrecoverable:
                log.error(str(e))
            else:
                raise

    def set_node_network_state(self, node_address: Address, network_state: str):
        state_change = ActionChangeNodeNetworkState(node_address, network_state)
        self.handle_state_change(state_change)

    def start_health_check_for(self, node_address: Address):
        # This function is a noop during initialization. It can be called
        # through the alarm task while polling for new channel events.  The
        # healthcheck will be started by self.start_neighbours_healthcheck()
        if self.transport:
            self.transport.start_health_check(node_address)

    def _callback_new_block(self, latest_block: Dict):
        """Called once a new block is detected by the alarm task.

        Note:
            This should be called only once per block, otherwise there will be
            duplicated `Block` state changes in the log.

            Therefore this method should be called only once a new block is
            mined with the corresponding block data from the AlarmTask.
        """
        # User facing APIs, which have on-chain side-effects, force polled the
        # blockchain to update the node's state. This force poll is used to
        # provide a consistent view to the user, e.g. a channel open call waits
        # for the transaction to be mined and force polled the event to update
        # the node's state. This pattern introduced a race with the alarm task
        # and the task which served the user request, because the events are
        # returned only once per filter. The lock below is to protect against
        # these races (introduced by the commit
        # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d)
        with self.event_poll_lock:
            latest_block_number = latest_block['number']
            confirmation_blocks = self.config['blockchain']['confirmation_blocks']

            # Handle testing private chains: the chain can be shorter than
            # `confirmation_blocks`. Clamp *before* querying the node so the
            # fetched block and the block number of the `Block` state change
            # always agree (and a negative block number is never requested).
            confirmed_block_number = max(
                GENESIS_BLOCK_NUMBER,
                latest_block_number - confirmation_blocks,
            )
            confirmed_block = self.chain.client.web3.eth.getBlock(confirmed_block_number)

            for event in self.blockchain_events.poll_blockchain_events(confirmed_block_number):
                # These state changes will be processed with a block_number
                # which is /larger/ than the ChainState's block_number.
                on_blockchain_event(self, event)

            # On restart the Raiden node will re-create the filters with the
            # ethereum node. These filters will have the from_block set to the
            # value of the latest Block state change. To avoid missing events
            # the Block state change is dispatched only after all of the events
            # have been processed.
            #
            # This means on some corner cases a few events may be applied
            # twice, this will happen if the node crashed and some events have
            # been processed but the Block state change has not been
            # dispatched.
            state_change = Block(
                block_number=confirmed_block_number,
                gas_limit=confirmed_block['gasLimit'],
                block_hash=BlockHash(bytes(confirmed_block['hash'])),
            )
            self.handle_state_change(state_change)

    def _initialize_transactions_queues(self, chain_state: ChainState):
        """ Re-dispatch pending transactions found in `chain_state`. """
        pending_transactions = views.get_pending_transactions(chain_state)

        log.debug(
            'Processing pending transactions',
            num_pending_transactions=len(pending_transactions),
            node=pex(self.address),
        )

        with self.dispatch_events_lock:
            for transaction in pending_transactions:
                self._handle_raiden_event(transaction)

    def _initialize_payment_statuses(self, chain_state: ChainState):
        """ Re-initialize targets_to_identifiers_to_statuses. """

        with self.payment_identifier_lock:
            for task in chain_state.payment_mapping.secrethashes_to_task.values():
                if not isinstance(task, InitiatorTask):
                    continue

                # Every transfer in the transfers_list must have the same target
                # and payment_identifier, so using the first transfer is
                # sufficient.
                initiator = next(iter(task.manager_state.initiator_transfers.values()))
                transfer = initiator.transfer
                target = transfer.target
                identifier = transfer.payment_identifier
                balance_proof = transfer.balance_proof
                self.targets_to_identifiers_to_statuses[target][identifier] = PaymentStatus(
                    payment_identifier=identifier,
                    amount=transfer.lock.amount,
                    token_network_identifier=balance_proof.token_network_identifier,
                    payment_done=AsyncResult(),
                )

    def _initialize_messages_queues(self, chain_state: ChainState):
        """ Push the message queues to the transport. """
        events_queues = views.get_all_messagequeues(chain_state)

        for queue_identifier, event_queue in events_queues.items():
            self.start_health_check_for(queue_identifier.recipient)

            for event in event_queue:
                message = message_from_sendevent(event, self.address)
                self.sign(message)
                self.transport.send_async(queue_identifier, message)

    def _initialize_whitelists(self, chain_state: ChainState):
        """ Whitelist neighbors and mediated transfer targets on transport """

        for neighbour in views.all_neighbour_nodes(chain_state):
            if neighbour == ConnectionManager.BOOTSTRAP_ADDR:
                continue
            self.transport.whitelist(neighbour)

        events_queues = views.get_all_messagequeues(chain_state)

        for event_queue in events_queues.values():
            for event in event_queue:
                if isinstance(event, SendLockedTransfer):
                    transfer = event.transfer
                    if transfer.initiator == self.address:
                        self.transport.whitelist(address=transfer.target)

    def sign(self, message: Message):
        """ Sign message inplace. """
        if not isinstance(message, SignedMessage):
            raise ValueError('{} is not signable.'.format(repr(message)))

        message.sign(self.signer)

    def install_all_blockchain_filters(
            self,
            token_network_registry_proxy: TokenNetworkRegistry,
            secret_registry_proxy: SecretRegistry,
            from_block: BlockNumber,
    ):
        """ Install event listeners for the registry, the secret registry and
        every token network known to the node state, starting at `from_block`.
        """
        with self.event_poll_lock:
            node_state = views.state_from_raiden(self)
            token_networks = views.get_token_network_identifiers(
                node_state,
                token_network_registry_proxy.address,
            )

            self.blockchain_events.add_token_network_registry_listener(
                token_network_registry_proxy=token_network_registry_proxy,
                contract_manager=self.contract_manager,
                from_block=from_block,
            )
            self.blockchain_events.add_secret_registry_listener(
                secret_registry_proxy=secret_registry_proxy,
                contract_manager=self.contract_manager,
                from_block=from_block,
            )

            for token_network in token_networks:
                token_network_proxy = self.chain.token_network(
                    TokenNetworkAddress(token_network),
                )
                self.blockchain_events.add_token_network_listener(
                    token_network_proxy=token_network_proxy,
                    contract_manager=self.contract_manager,
                    from_block=from_block,
                )

    def connection_manager_for_token_network(
            self,
            token_network_identifier: TokenNetworkID,
    ) -> ConnectionManager:
        """ Return (creating on first use) the ConnectionManager for a token
        network registered in the default registry.

        Raises:
            InvalidAddress: if the identifier is malformed or not registered.
        """
        if not is_binary_address(token_network_identifier):
            raise InvalidAddress('token address is not valid.')

        known_token_networks = views.get_token_network_identifiers(
            views.state_from_raiden(self),
            self.default_registry.address,
        )

        if token_network_identifier not in known_token_networks:
            raise InvalidAddress('token is not registered.')

        manager = self.tokennetworkids_to_connectionmanagers.get(token_network_identifier)

        if manager is None:
            manager = ConnectionManager(self, token_network_identifier)
            self.tokennetworkids_to_connectionmanagers[token_network_identifier] = manager

        return manager

    def mediated_transfer_async(
            self,
            token_network_identifier: TokenNetworkID,
            amount: PaymentAmount,
            target: TargetAddress,
            identifier: PaymentID,
            secret: Secret = None,
            secret_hash: SecretHash = None,
    ) -> PaymentStatus:
        """ Transfer `amount` between this node and `target`.

        This method will start an asynchronous transfer, the transfer might fail
        or succeed depending on a couple of factors:

            - Existence of a path that can be used, through the usage of direct
              or intermediary channels.
            - Network speed, making the transfer sufficiently fast so it doesn't
              expire.
        """
        if secret is None:
            secret = random_secret()

        payment_status = self.start_mediated_transfer_with_secret(
            token_network_identifier,
            amount,
            target,
            identifier,
            secret,
            secret_hash,
        )

        return payment_status

    def start_mediated_transfer_with_secret(
            self,
            token_network_identifier: TokenNetworkID,
            amount: PaymentAmount,
            target: TargetAddress,
            identifier: PaymentID,
            secret: Secret,
            secret_hash: SecretHash = None,
    ) -> PaymentStatus:
        """ Start (or join) a payment to `target` using the given `secret`.

        If a payment with the same target and identifier is already tracked
        and matches the token network and amount, its status is returned.

        Raises:
            RaidenUnrecoverableError: if the secret is already registered on-chain.
            PaymentConflict: if a different payment uses the same identifier.
        """
        if secret_hash is None:
            secret_hash = sha3(secret)

        # LEFTODO: Supply a proper block id
        secret_registered = self.default_secret_registry.check_registered(
            secrethash=secret_hash,
            block_identifier='latest',
        )
        if secret_registered:
            raise RaidenUnrecoverableError(
                f'Attempted to initiate a locked transfer with secrethash {pex(secret_hash)}.'
                f' That secret is already registered onchain.',
            )

        self.start_health_check_for(Address(target))

        if identifier is None:
            identifier = create_default_identifier()

        with self.payment_identifier_lock:
            payment_status = self.targets_to_identifiers_to_statuses[target].get(identifier)
            if payment_status:
                payment_status_matches = payment_status.matches(
                    token_network_identifier,
                    amount,
                )
                if not payment_status_matches:
                    raise PaymentConflict(
                        'Another payment with the same id is in flight',
                    )

                return payment_status

            payment_status = PaymentStatus(
                payment_identifier=identifier,
                amount=amount,
                token_network_identifier=token_network_identifier,
                payment_done=AsyncResult(),
                secret=secret,
                secret_hash=secret_hash,
            )
            self.targets_to_identifiers_to_statuses[target][identifier] = payment_status

        init_initiator_statechange = initiator_init(
            raiden=self,
            transfer_identifier=identifier,
            transfer_amount=amount,
            transfer_secret=secret,
            token_network_identifier=token_network_identifier,
            target_address=target,
        )

        # Dispatch the state change even if there are no routes to create the
        # wal entry.
        self.handle_state_change(init_initiator_statechange)

        return payment_status

    def mediate_mediated_transfer(self, transfer: LockedTransfer):
        init_mediator_statechange = mediator_init(self, transfer)
        self.handle_state_change(init_mediator_statechange)

    def target_mediated_transfer(self, transfer: LockedTransfer):
        self.start_health_check_for(transfer.initiator)
        init_target_statechange = target_init(transfer)
        self.handle_state_change(init_target_statechange)

    def maybe_upgrade_db(self):
        """ Run the database upgrade manager on ``self.database_path``. """
        manager = UpgradeManager(db_filename=self.database_path)
        manager.run()
class AceClient(object):

    def __init__(self, clientcounter, ace, connect_timeout=5, result_timeout=10):
        '''
        Connect to the AceEngine telnet API described by `ace`
        (``aceHostIP``/``aceAPIport``) and spawn the response reader greenlet.
        Raises AceException if the connection cannot be established.
        '''
        # Telnet socket response buffer
        self._recvbuffer = None
        # Telnet socket response buffer read timeout
        self._recvbuffertimeout = 30
        # AceEngine socket
        self._socket = None
        # ClientCounter
        self._clientcounter = clientcounter
        # AceEngine read result timeout
        self._resulttimeout = float(result_timeout)
        # AceEngine product key
        self._product_key = None
        # Result (Created with AsyncResult() on call)
        self._auth = AsyncResult()
        # Result for START URL
        self._url = AsyncResult()
        # Response time from AceEngine to get URL or DATA
        self._videotimeout = None
        # Result for CID
        self._cid = AsyncResult()
        # Result for LOADASYNC
        self._loadasync = AsyncResult()
        # Current STATUS
        self._status = AsyncResult()
        # Current EVENT
        self._event = AsyncResult()
        # Current STATE
        self._state = AsyncResult()
        # Current AUTH
        self._gender = None
        self._age = None
        # Seekback seconds.
        self._seekback = None
        # Did we get START command again? For seekback.
        self._started_again = Event()

        try:
            self._socket = Telnet(ace['aceHostIP'], ace['aceAPIport'], connect_timeout)
            logging.debug('Successfully connected to AceStream on %s:%s' % (ace['aceHostIP'], ace['aceAPIport']))
        except Exception:
            errmsg = 'The are no alive AceStream Engines found!'
            raise AceException(errmsg)
        else:
            gevent.spawn(self._recvData)  # Spawning telnet data reader

    def destroy(self):
        '''
        AceClient Destructor: ask the engine to SHUTDOWN (best effort).
        '''
        # Trying to disconnect
        logging.debug('Destroying AceStream client.....')
        try:
            self._write(AceMessage.request.SHUTDOWN)
        except Exception:
            pass  # Ignore exceptions on destroy

    def reset(self):
        '''
        Reset initial values
        '''
        self._started_again.clear()
        self._url.set()
        self._loadasync.set()
        self._cid.set()

    def _write(self, message):
        '''Send one CRLF-terminated command line to the engine.'''
        try:
            logging.debug('>>> %s' % message)
            self._socket.write('%s\r\n' % message)
        except gevent.socket.error as e:
            raise AceException('Telnet exception at socket write %s' % repr(e))

    @staticmethod
    def _parse_params(tokens):
        '''
        Build a dict from the ``key=value`` items of `tokens`.

        Items without ``=`` are skipped.  Only the first ``=`` separates the
        key from the value, so values that contain ``=`` themselves (e.g. URL
        query strings) no longer make the unpacking fail.
        '''
        return dict(token.split('=', 1) for token in tokens if '=' in token)

    def aceInit(self, gender=AceConst.SEX_MALE, age=AceConst.AGE_25_34, product_key=None, videoseekback=0, videotimeout=30):
        '''
        Perform the HELLOBG / READY handshake with the engine.
        Raises AceException on timeout or when NOTREADY is received.
        '''
        self._gender = gender
        self._age = age
        self._product_key = product_key
        self._seekback = videoseekback
        self._videotimeout = float(videotimeout)
        self._started_again.clear()

        self._auth = AsyncResult()
        self._write(AceMessage.request.HELLO)  # Sending HELLOBG
        try:
            params = self._auth.get(timeout=self._resulttimeout)
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. HELLOTS not resived!' % self._resulttimeout
            raise AceException(errmsg)

        self._auth = AsyncResult()
        self._write(AceMessage.request.READY(params.get('key', ''), self._product_key))
        try:
            if self._auth.get(timeout=self._resulttimeout) == 'NOTREADY':  # Get NOTREADY instead AUTH user_auth_level
                errmsg = 'NOTREADY recived from AceEngine! Wrong acekey?'
                raise AceException(errmsg)
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. AUTH not resived!' % self._resulttimeout
            raise AceException(errmsg)

        if int(params.get('version_code', 0)) >= 3003600:  # Display download_stopped massage
            params_dict = {'use_stop_notifications': '1'}
            self._write(AceMessage.request.SETOPTIONS(params_dict))

    def START(self, command, paramsdict, acestreamtype):
        '''
        Start video method
        Return the url provided by AceEngine
        '''
        paramsdict['stream_type'] = ' '.join(['{}={}'.format(k, v) for k, v in acestreamtype.items()])
        self._url = AsyncResult()
        self._write(AceMessage.request.START(command.upper(), paramsdict))
        try:
            return self._url.get(timeout=self._videotimeout)  # Get url for play from AceEngine
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. START URL not resived!' % self._videotimeout
            raise AceException(errmsg)

    def STOP(self):
        '''
        Stop video method
        '''
        self._state = AsyncResult()
        self._write(AceMessage.request.STOP)
        try:
            self._state.get(timeout=self._resulttimeout)
            self._started_again.clear()
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. STATE 0 (IDLE) not resived!' % self._resulttimeout
            raise AceException(errmsg)

    def LOADASYNC(self, command, params):
        '''Send LOADASYNC and return the decoded LOADRESP content info.'''
        self._loadasync = AsyncResult()
        self._write(AceMessage.request.LOADASYNC(command.upper(), random.randint(1, 100000), params))
        try:
            return self._loadasync.get(timeout=self._resulttimeout)  # Get _contentinfo json
        except gevent.Timeout:
            errmsg = 'Engine response %ssec time exceeded. LOADARESP not resived!' % self._resulttimeout
            raise AceException(errmsg)

    def GETCONTENTINFO(self, command, value):
        paramsdict = {command: value, 'developer_id': '0', 'affiliate_id': '0', 'zone_id': '0'}
        return self.LOADASYNC(command, paramsdict)

    def GETCID(self, command, value):
        '''Return the content id (without the leading ``##``).'''
        contentinfo = self.GETCONTENTINFO(command, value)
        if contentinfo['status'] in (1, 2):
            paramsdict = {
                'checksum': contentinfo['checksum'],
                'infohash': contentinfo['infohash'],
                'developer_id': '0',
                'affiliate_id': '0',
                'zone_id': '0'
            }
            self._cid = AsyncResult()
            self._write(AceMessage.request.GETCID(paramsdict))
            try:
                return self._cid.get(timeout=self._resulttimeout)[2:]  # ##CID
            except gevent.Timeout:
                errmsg = 'Engine response time %ssec exceeded. CID not resived!' % self._resulttimeout
                raise AceException(errmsg)
        else:
            errmsg = 'LOADASYNC returned error with message: %s' % contentinfo['message']
            raise AceException(errmsg)

    def GETINFOHASH(self, command, value, idx=0):
        '''Return ``(infohash, file_name)`` for file index `idx`.'''
        contentinfo = self.GETCONTENTINFO(command, value)
        if contentinfo['status'] in (1, 2):
            return contentinfo['infohash'], [x[0] for x in contentinfo['files'] if x[1] == int(idx)][0]
        elif contentinfo['status'] == 0:
            errmsg = 'LOADASYNC returned status 0: The transport file does not contain audio/video files'
            raise AceException(errmsg)
        else:
            errmsg = 'LOADASYNC returned error with message: %s' % contentinfo['message']
            raise AceException(errmsg)

    def AceStreamReader(self, url, cid, req_headers=None):
        '''
        Get stream from AceEngine url and write it to client(s)
        '''
        logging.debug('Start StreamReader for url: %s' % url)
        with requests.Session() as session:
            if req_headers:
                session.headers.update(req_headers)
                logging.debug('Sending headers from client to AceEngine: %s' % session.headers)
            try:
                self._write(AceMessage.request.EVENT('play'))
                # AceEngine return link for HLS stream
                if url.endswith('.m3u8'):
                    # Recently forwarded chunk URLs, so a playlist refresh does
                    # not resend them (bounded to the last 15 entries).
                    _used_chunks = []
                    while self._state.get(timeout=self._resulttimeout)[0] in ('2', '3'):
                        for line in session.get(url, stream=True, timeout=(5, self._videotimeout)).iter_lines():
                            if line.startswith(b'download not found'):
                                return
                            if line.startswith(b'http://') and line not in _used_chunks:
                                self.RAWDataReader(
                                    session.get(line, stream=True, timeout=(5, self._videotimeout)),
                                    self._clientcounter.getClientsList(cid))
                                _used_chunks.append(line)
                                if len(_used_chunks) > 15:
                                    _used_chunks.pop(0)
                # AceStream return link for HTTP stream
                else:
                    self.RAWDataReader(
                        session.get(url, stream=True, timeout=(5, self._videotimeout)),
                        self._clientcounter.getClientsList(cid))
            except Exception as err:
                clients = self._clientcounter.getClientsList(cid)
                if clients:
                    logging.error('"%s" StreamReader error: %s' % (clients[0].channelName, repr(err)))
                    # b'0\r\n\r\n' - send the chunked trailer
                    gevent.wait([gevent.spawn(self.write_chunk, c, b'', True) for c in clients])

    def RAWDataReader(self, stream, clients):
        '''Fan out every non-empty chunk of `stream` to all `clients`.'''
        for chunk in stream.iter_content(chunk_size=1048576 if 'Content-Length' in stream.headers else None):
            if chunk:
                gevent.wait([gevent.spawn(self.write_chunk, c, chunk) for c in clients])

    def write_chunk(self, client, chunk, chunk_trailer=None):
        '''
        Write `chunk` to `client`, HTTP-chunk-encoded unless the client uses a
        transcoder.  The client is destroyed on write failure or when
        `chunk_trailer` is truthy.
        '''
        try:
            if client.transcoder:
                client.out.write(chunk)
            else:
                client.out.write(b'%X\r\n%s\r\n' % (len(chunk), chunk))
        except Exception:
            client.destroy()  # Client disconnected
        if chunk_trailer:
            client.destroy()

    def _recvData(self):
        '''
        Data receiver method for greenlet
        '''
        while 1:
            try:
                with gevent.timeout.Timeout(self._recvbuffertimeout):
                    self._recvbuffer = self._socket.read_until('\r\n', timeout=None).strip()
            except EOFError as e:
                # if the connection is closed and no cooked data is available.
                raise AceException('Telnet exception at socket read. AceClient destroyed %s' % repr(e))
            # Ignore error occurs while reading blank lines from socket in STATE 0 (IDLE)
            except gevent.socket.timeout:
                pass
            # SHUTDOWN socket connection if AceEngine STATE 0 (IDLE) and we didn't read anything from socket until Nsec
            except gevent.timeout.Timeout:
                self.destroy()
                self._clientcounter.idleAce = None
            else:  # Parsing everything only if the string is not empty
                logging.debug('<<< %s' % requests.compat.unquote(self._recvbuffer))
                # HELLOTS
                if self._recvbuffer.startswith('HELLOTS'):
                    # version=engine_version version_code=version_code key=request_key http_port=http_port
                    self._auth.set(self._parse_params(self._recvbuffer.split()))
                # NOTREADY
                elif self._recvbuffer.startswith('NOTREADY'):
                    self._auth.set('NOTREADY')
                # AUTH
                elif self._recvbuffer.startswith('AUTH'):
                    self._auth.set(self._recvbuffer.split()[1])  # user_auth_level
                # START
                elif self._recvbuffer.startswith('START'):
                    # url [ad=1 [interruptable=1]] [stream=1] [pos=position]
                    params = self._parse_params(self._recvbuffer.split())
                    if not self._seekback or self._started_again.ready() or params.get('stream', '') != '1':
                        # If seekback is disabled, we use link in first START command.
                        # If seekback is enabled, we wait for first START command and
                        # ignore it, then do seekback in first EVENT position command
                        # AceStream sends us STOP and START again with new link.
                        # We use only second link then.
                        self._url.set(self._recvbuffer.split()[1])  # url for play
                # LOADRESP
                elif self._recvbuffer.startswith('LOADRESP'):
                    self._loadasync.set(
                        requests.compat.json.loads(
                            requests.compat.unquote(''.join(self._recvbuffer.split()[2:]))))
                # STATE
                elif self._recvbuffer.startswith('STATE'):  # tuple of (state_id, time of appearance)
                    self._state.set((self._recvbuffer.split()[1], gevent.time.time()))
                # STATUS
                elif self._recvbuffer.startswith('STATUS'):
                    self._tempstatus = self._recvbuffer.split()[1]
                    # main:idle, main:loading, main:starting, main:check,
                    # main:wait, main:prebuf/main:buf and main:dl are
                    # currently ignored; only errors are propagated.
                    if self._tempstatus.startswith('main:err'):  # err;error_id;error_message
                        status_parts = self._tempstatus.split(';')
                        self._status.set_exception(
                            AceException('%s with message %s' % (status_parts[0], status_parts[2])))
                # CID
                elif self._recvbuffer.startswith('##'):
                    self._cid.set(self._recvbuffer)
                # INFO
                elif self._recvbuffer.startswith('INFO'):
                    pass
                # EVENT
                elif self._recvbuffer.startswith('EVENT'):
                    self._tempevent = self._recvbuffer.split()
                    if self._seekback and not self._started_again.ready() and 'livepos' in self._tempevent:
                        params = self._parse_params(self._tempevent)
                        self._write(AceMessage.request.LIVESEEK(int(params['last']) - self._seekback))
                        self._started_again.set()
                    elif 'getuserdata' in self._tempevent:
                        self._write(AceMessage.request.USERDATA(self._gender, self._age))
                    # 'cansave', 'showurl' and 'download_stopped' events are ignored.
                # PAUSE
                elif self._recvbuffer.startswith('PAUSE'):
                    self._write(AceMessage.request.EVENT('pause'))
                # RESUME
                elif self._recvbuffer.startswith('RESUME'):
                    self._write(AceMessage.request.EVENT('play'))
                # STOP
                elif self._recvbuffer.startswith('STOP'):
                    pass
                # SHUTDOWN
                elif self._recvbuffer.startswith('SHUTDOWN'):
                    self._socket.close()
                    logging.debug('AceClient destroyed')
                    return
            finally:
                gevent.sleep()
class AceClient(object):
    '''
    Client for the AceStream engine telnet API.

    The constructor opens a Telnet connection and spawns `_recvData` as a
    greenlet. Request methods write a command and then block on an
    AsyncResult that `_recvData` fills in when the matching reply arrives.
    '''

    def __init__(self, ace, connect_timeout=5, result_timeout=10):
        '''
        :param ace: mapping that provides 'aceHostIP' and 'aceAPIport'
        :param connect_timeout: Telnet connect timeout (seconds)
        :param result_timeout: seconds to wait for engine replies
        :raises AceException: if the Telnet connection cannot be opened
        '''
        # Telnet response buffer
        self._recvbuffer = None
        # AceEngine socket
        self._socket = None
        # AceEngine read result timeout
        self._resulttimeout = float(result_timeout)
        # Shutting down flag
        self._shuttingDown = Event()
        # AceEngine product key
        self._product_key = None
        # Result for HELLOTS / AUTH / NOTREADY / LOADRESP (re-created on each call)
        self._result = AsyncResult()
        # Result for START URL
        self._urlresult = AsyncResult()
        # URL response time from AceEngine
        self._videotimeout = None
        # Result for CID
        self._cidresult = AsyncResult()
        # Current STATUS
        self._status = AsyncResult()
        # Current EVENT
        self._event = AsyncResult()
        # Current STATE as a (state_id, time of appearance) tuple
        self._state = AsyncResult()
        # Current AUTH
        self._gender = None
        self._age = None
        # Seekback seconds.
        self._seekback = None
        # Did we get START command again? For seekback.
        self._started_again = Event()
        # AceEngine Streamreader ring buffer with max number of chunks in queue
        self._streamReaderQueue = gevent.queue.Queue(maxsize=1000)

        logger = logging.getLogger('AceClient')
        # Try to connect AceStream engine. Only ordinary exceptions are mapped to
        # AceException; GreenletExit / KeyboardInterrupt are allowed to propagate.
        try:
            self._socket = Telnet(ace['aceHostIP'], ace['aceAPIport'], connect_timeout)
        except Exception:
            errmsg = 'The are no alive AceStream Engines found!'
            raise AceException(errmsg)
        logger.debug('Successfully connected to AceStream on %s:%s' % (ace['aceHostIP'], ace['aceAPIport']))
        # Spawning telnet data reader greenlet
        gevent.spawn(self._recvData)

    @staticmethod
    def _parse_params(tokens):
        '''
        Return a dict built from the ``key=value`` items of *tokens*.

        Only the first '=' separates key from value, so a value that itself
        contains '=' (e.g. a URL with a query string) does not raise ValueError.
        '''
        return dict(token.split('=', 1) for token in tokens if '=' in token)

    def destroy(self):
        '''
        AceClient Destructor.

        Returns immediately if shutdown already started. Otherwise sets
        ``_result`` to None (releasing any waiter), marks the client as
        shutting down and sends SHUTDOWN to the engine, ignoring write errors.
        '''
        logger = logging.getLogger('AceClient_destroy')
        if self._shuttingDown.ready():
            return  # Already in the middle of destroying
        self._result.set()
        self._shuttingDown.set()
        # Trying to disconnect
        try:
            logger.debug('Destroying AceStream client.....')
            self._write(AceMessage.request.SHUTDOWN)
        except Exception as e:
            # Best effort: the engine may already be gone.
            logger.debug('Ignoring error while sending SHUTDOWN: %s' % repr(e))

    def reset(self):
        ''' Clear the seekback flag and release waiters on pending results. '''
        self._started_again.clear()
        self._result.set()
        self._urlresult.set()

    def _write(self, message):
        '''
        Send *message* to AceEngine terminated by CRLF.

        :raises AceException: if the socket reports EOF
        '''
        try:
            logger = logging.getLogger('AceClient_write')
            logger.debug('>>> %s' % message)
            self._socket.write('%s\r\n' % message)
        except EOFError as e:
            raise AceException('Write error! %s' % repr(e))

    def aceInit(self, gender=AceConst.SEX_MALE, age=AceConst.AGE_25_34, product_key=None, videoseekback=0, videotimeout=0):
        '''
        Perform the HELLOBG / READY handshake with AceEngine.

        :raises AceException: on HELLOTS or AUTH timeout, or a NOTREADY reply
        '''
        self._gender = gender
        self._age = age
        self._product_key = product_key
        self._seekback = videoseekback
        self._videotimeout = float(videotimeout)
        self._started_again.clear()

        self._result = AsyncResult()
        self._write(AceMessage.request.HELLO)  # Sending HELLOBG
        try:
            params = self._result.get(timeout=self._resulttimeout)
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. HELLOTS not resived!' % self._resulttimeout
            raise AceException(errmsg)

        # Fresh result so the READY reply (AUTH or NOTREADY) is what we wait for
        self._result = AsyncResult()
        self._write(AceMessage.request.READY(params.get('key', ''), self._product_key))
        try:
            auth = self._result.get(timeout=self._resulttimeout)
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. AUTH not resived!' % self._resulttimeout
            raise AceException(errmsg)
        if auth == 'NOTREADY':  # Got NOTREADY instead of AUTH user_auth_level
            errmsg = 'NOTREADY recived from AceEngine! Wrong acekey?'
            raise AceException(errmsg)

        if int(params.get('version_code', 0)) >= 3003600:  # Display download_stopped message
            params_dict = {'use_stop_notifications': '1'}
            self._write(AceMessage.request.SETOPTIONS(params_dict))

    def START(self, command, paramsdict, acestreamtype):
        '''
        Start video method.
        Returns the url provided by AceEngine.

        Note: adds a 'stream_type' entry to *paramsdict* in place.
        :raises AceException: if no URL arrives within the video timeout
        '''
        paramsdict['stream_type'] = ' '.join(['{}={}'.format(k, v) for k, v in acestreamtype.items()])
        self._urlresult = AsyncResult()
        self._write(AceMessage.request.START(command.upper(), paramsdict))
        try:
            return self._urlresult.get(timeout=self._videotimeout)  # Get url for play from AceEngine
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. START URL not resived!' % self._videotimeout
            raise AceException(errmsg)

    def STOP(self):
        '''
        Stop video method: send STOP and wait for the next STATE message.

        :raises AceException: if no STATE arrives within the result timeout
        '''
        self._state = AsyncResult()
        self._write(AceMessage.request.STOP)
        try:
            self._state.get(timeout=self._resulttimeout)
            self._started_again.clear()
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. STATE 0 (IDLE) not resived!' % self._resulttimeout
            raise AceException(errmsg)

    def LOADASYNC(self, command, params):
        '''
        Send LOADASYNC and return the decoded LOADRESP payload.

        :raises AceException: if LOADRESP does not arrive in time
        '''
        self._result = AsyncResult()
        self._write(AceMessage.request.LOADASYNC(command.upper(), random.randint(1, 100000), params))
        try:
            return self._result.get(timeout=self._resulttimeout)  # Get _contentinfo json
        except gevent.Timeout:
            errmsg = 'Engine response %ssec time exceeded. LOADARESP not resived!' % self._resulttimeout
            raise AceException(errmsg)

    def GETCONTENTINFO(self, command, value):
        ''' Request content info for *value* via LOADASYNC. '''
        paramsdict = {command: value, 'developer_id': '0', 'affiliate_id': '0', 'zone_id': '0'}
        return self.LOADASYNC(command, paramsdict)

    def GETCID(self, command, value):
        '''
        Return the content id (without the leading '##'), or '' if empty.

        :raises AceException: on LOADASYNC error status or CID timeout
        '''
        contentinfo = self.GETCONTENTINFO(command, value)
        if contentinfo['status'] not in (1, 2):
            errmsg = 'LOADASYNC returned error with message: %s' % contentinfo['message']
            raise AceException(errmsg)

        paramsdict = {'checksum': contentinfo['checksum'], 'infohash': contentinfo['infohash'],
                      'developer_id': '0', 'affiliate_id': '0', 'zone_id': '0'}
        self._cidresult = AsyncResult()
        self._write(AceMessage.request.GETCID(paramsdict))
        try:
            cid = self._cidresult.get(timeout=self._resulttimeout)
        except gevent.Timeout:
            errmsg = 'Engine response time %ssec exceeded. CID not resived!' % self._resulttimeout
            raise AceException(errmsg)
        return (cid or '')[2:]

    def GETINFOHASH(self, command, value, idx=0):
        '''
        Return (infohash, first file entry whose index equals *idx*).

        :raises AceException: on status 0 or any other error status
        '''
        contentinfo = self.GETCONTENTINFO(command, value)
        if contentinfo['status'] in (1, 2):
            return contentinfo['infohash'], [x[0] for x in contentinfo['files'] if x[1] == int(idx)][0]
        elif contentinfo['status'] == 0:
            errmsg = 'LOADASYNC returned status 0: The transport file does not contain audio/video files'
            raise AceException(errmsg)
        else:
            errmsg = 'LOADASYNC returned error with message: %s' % contentinfo['message']
            raise AceException(errmsg)

    def StreamReader(self, url, cid, counter, req_headers=None):
        '''
        Relay the video stream at *url* to the clients registered for *cid*.

        For an HLS playlist (.m3u8) the playlist is re-fetched every 4 seconds
        while the engine STATE is 2 or 3, and each new chunk is handed to
        RAWDataReader. Any other URL is read directly. On exit the ring
        buffer is cleared and all clients of *cid* are removed from *counter*.
        '''
        logger = logging.getLogger('StreamReader')
        logger.debug('Start StreamReader for url: %s' % url)
        self._write(AceMessage.request.EVENT('play'))
        with requests.Session() as session:
            if req_headers:
                logger.debug('Sending headers from client to AceEngine: %s' % req_headers)
                session.headers.update(req_headers)
            try:
                if url.endswith('.m3u8'):
                    # AceEngine returned a link for an HLS stream
                    used_chunks = []  # recently relayed chunk URLs (at most 15 kept)
                    while self._state.get(timeout=self._resulttimeout)[0] in ('2', '3'):
                        playlist = session.get(url, stream=True, timeout=(5, None))
                        playlist.raise_for_status()  # makes the HTTPError handler below reachable
                        for line in playlist.iter_lines():
                            if self._state.get(timeout=self._resulttimeout)[0] not in ('2', '3'):
                                return
                            if line.startswith(b'http://') and line not in used_chunks:
                                self.RAWDataReader(session.get(line, stream=True, timeout=(5, None)).raw, cid, counter)
                                used_chunks.append(line)
                                if len(used_chunks) > 15:
                                    used_chunks.pop(0)
                        gevent.sleep(4)
                else:
                    # AceStream returned a link for an HTTP stream
                    response = session.get(url, stream=True, timeout=(5, None))
                    response.raise_for_status()
                    self.RAWDataReader(response.raw, cid, counter)
            except requests.exceptions.HTTPError as err:
                logger.error('An http error occurred while connecting to aceengine: %s' % repr(err))
            except requests.exceptions.RequestException:
                logger.error('There was an ambiguous exception that occurred while handling request')
            except Exception as err:
                logger.error('Unexpected error in streamreader %s' % repr(err))
            finally:
                self._streamReaderQueue.queue.clear()
                counter.deleteAll(cid)

    def RAWDataReader(self, stream, cid, counter):
        '''
        Copy chunks from *stream* to the ring buffer and to every client of *cid*.

        Reads only while STATE is 2 (DOWNLOADING). Returns on end of stream,
        when no clients remain, or when STATE 3 (BUFFERING) has lasted at
        least the video timeout.
        '''
        logger = logging.getLogger('RAWDataReader')

        while self._state.get(timeout=self._resulttimeout)[0] in ('2', '3'):
            gevent.sleep()
            if self._state.get(timeout=self._resulttimeout)[0] == '2':  # Read data from AceEngine only if STATE 2 (DOWNLOADING)
                data = stream.read(requests.models.CONTENT_CHUNK_SIZE)
                if not data:
                    return
                try:
                    self._streamReaderQueue.put_nowait(data)
                except gevent.queue.Full:
                    # Ring buffer: drop the oldest chunk to make room
                    self._streamReaderQueue.get_nowait()
                    self._streamReaderQueue.put_nowait(data)

                clients = counter.getClients(cid)
                if not clients:
                    return
                for c in clients:
                    try:
                        c.queue.put(data, timeout=5)
                    except gevent.queue.Full:
                        if len(clients) > 1:
                            logger.warning('Client %s does not read data from buffer until 5sec - disconnect it' % c.handler.clientip)
                            c.destroy()

            elif (time.time() - self._state.get(timeout=self._resulttimeout)[1]) >= self._videotimeout:  # STATE 3 (BUFFERING)
                logger.warning('No data received from AceEngine for %ssec - broadcast stoped' % self._videotimeout)
                return

    def _recvData(self):
        '''
        Data receiver method for greenlet.

        Reads CRLF-terminated lines from the engine and dispatches them by
        prefix. Stops on SHUTDOWN, when the greenlet is killed, or on a read
        error (which also sets the shutting-down flag).
        '''
        logger = logging.getLogger('AceClient_recvdata')

        while 1:
            gevent.sleep()
            try:
                self._recvbuffer = self._socket.read_until('\r\n').strip()
                logger.debug('<<< %s' % requests.compat.unquote(self._recvbuffer))
            except gevent.GreenletExit:
                break
            except Exception:
                # If something happened during read, abandon reader.
                logger.error('Exception at socket read. AceClient destroyed')
                self._shuttingDown.set()
                return

            # HELLOTS
            if self._recvbuffer.startswith('HELLOTS'):
                # version=engine_version version_code=version_code key=request_key http_port=http_port
                self._result.set(self._parse_params(self._recvbuffer.split()))
            # NOTREADY
            elif self._recvbuffer.startswith('NOTREADY'):
                self._result.set('NOTREADY')
            # AUTH
            elif self._recvbuffer.startswith('AUTH'):
                self._result.set(self._recvbuffer.split()[1])  # user_auth_level
            # START
            elif self._recvbuffer.startswith('START'):
                # url [ad=1 [interruptable=1]] [stream=1] [pos=position]
                params = self._parse_params(self._recvbuffer.split())
                if not self._seekback or self._started_again.ready() or params.get('stream', '') != '1':
                    # If seekback is disabled, we use link in first START command.
                    # If seekback is enabled, we wait for first START command and
                    # ignore it, then do seekback in first EVENT position command
                    # AceStream sends us STOP and START again with new link.
                    # We use only second link then.
                    self._urlresult.set(self._recvbuffer.split()[1])  # url for play
            # LOADRESP
            elif self._recvbuffer.startswith('LOADRESP'):
                # ' '.join keeps single spaces inside the JSON payload ('' would drop them)
                self._result.set(requests.compat.json.loads(requests.compat.unquote(' '.join(self._recvbuffer.split()[2:]))))
            # STATE
            elif self._recvbuffer.startswith('STATE'):
                # tuple of (state_id, time of appearance)
                self._state.set((self._recvbuffer.split()[1], time.time()))
            # STATUS
            elif self._recvbuffer.startswith('STATUS'):
                self._tempstatus = self._recvbuffer.split()[1]
                # Only errors are reported; idle, loading, starting, check, wait,
                # prebuf/buf and dl statuses are ignored.
                if self._tempstatus.startswith('main:err'):  # err;error_id;error_message
                    fields = self._tempstatus.split(';')
                    message = fields[2] if len(fields) > 2 else ''
                    self._status.set_exception(AceException('%s with message %s' % (fields[0], message)))
            # CID
            elif self._recvbuffer.startswith('##'):
                self._cidresult.set(self._recvbuffer)
            # INFO
            elif self._recvbuffer.startswith('INFO'):
                pass
            # EVENT
            elif self._recvbuffer.startswith('EVENT'):
                self._tempevent = self._recvbuffer.split()
                if self._seekback and not self._started_again.ready() and 'livepos' in self._tempevent:
                    params = self._parse_params(self._tempevent)
                    self._write(AceMessage.request.LIVESEEK(int(params['last']) - self._seekback))
                    self._started_again.set()
                elif 'getuserdata' in self._tempevent:
                    self._write(AceMessage.request.USERDATA(self._gender, self._age))
                # 'cansave', 'showurl' and 'download_stopped' events need no action
            # PAUSE
            elif self._recvbuffer.startswith('PAUSE'):
                self._write(AceMessage.request.EVENT('pause'))
            # RESUME
            elif self._recvbuffer.startswith('RESUME'):
                self._write(AceMessage.request.EVENT('play'))
            # STOP
            elif self._recvbuffer.startswith('STOP'):
                pass
            # SHUTDOWN
            elif self._recvbuffer.startswith('SHUTDOWN'):
                self._socket.close()
                logger.debug('AceClient destroyed')
                break
class AceClient(object):
    '''
    Client for the AceStream engine telnet API.

    The constructor opens the Telnet connection. aceInit() spawns the
    `_recvData` reader greenlet and performs the HELLOBG/READY handshake.
    Request methods write a command and block on an AsyncResult that
    `_recvData` fills in.
    '''

    def __init__(self, clientcounter, ace, connect_timeout=5, result_timeout=10):
        '''
        :param clientcounter: client counter whose ``idleAce`` is cleared on destroy()
        :param ace: mapping that provides 'aceHostIP' and 'aceAPIport'
        :param connect_timeout: Telnet connect timeout (seconds)
        :param result_timeout: seconds to wait for engine replies
        :raises AceException: if the Telnet connection cannot be opened
        '''
        # Telnet socket response buffer
        self._recvbuffer = None
        # AceEngine socket
        self._socket = None
        # AceEngine read result timeout
        self._resulttimeout = result_timeout
        # AceEngine product key
        self._product_key = None
        # Result for HELLOTS / AUTH / NOTREADY (re-created on call)
        self._auth = AsyncResult()
        # Result for START URL
        self._url = AsyncResult()
        # Response time from AceEngine to get URL or DATA
        self._videotimeout = None
        # Result for CID
        self._cid = AsyncResult()
        # Result fo LOADASYNC
        self._loadasync = AsyncResult()
        # Current STATUS
        self._status = AsyncResult()
        # Current EVENT
        self._event = AsyncResult()
        # Current STATE
        self._state = AsyncResult()
        # Current AUTH
        self._gender = self._age = None
        # Seekback seconds.
        self._seekback = None
        # Did we get START command again? For seekback.
        self._started_again = Event()
        # ClientCounter
        self._clientcounter = clientcounter
        # AceConfig.ace
        self._ace = ace

        # Only ordinary exceptions become AceException; GreenletExit etc. propagate.
        try:
            self._socket = Telnet(self._ace['aceHostIP'], self._ace['aceAPIport'], connect_timeout)
        except Exception:
            errmsg = 'The are no alive AceStream Engines found!'
            raise AceException(errmsg)
        logging.debug('Successfully connected to AceStream on {aceHostIP}:{aceAPIport}'.format(**self._ace))

    @staticmethod
    def _parse_params(tokens):
        '''
        Return a dict built from the ``key=value`` items of *tokens*.

        Only the first '=' separates key from value, so a value that itself
        contains '=' does not raise ValueError.
        '''
        return dict(token.split('=', 1) for token in tokens if '=' in token)

    def destroy(self):
        '''
        AceClient Destructor: send SHUTDOWN to AceEngine (best effort) and
        always clear ``idleAce`` on the client counter.
        '''
        try:
            self._write(AceMessage.request.SHUTDOWN)
        except Exception:
            pass  # Ignore exceptions on destroy
        finally:
            self._clientcounter.idleAce = None

    def reset(self):
        ''' Reset initial values (releases waiters on pending results with None) '''
        self._started_again.clear()
        self._url.set()
        self._loadasync.set()
        self._cid.set()
        self._status.set()
        self._event.set()
        self._state.set()

    def _write(self, message):
        '''
        Send *message* to AceEngine terminated by CRLF.

        :raises AceException: on socket error
        '''
        try:
            self._socket.write('%s\r\n' % message)
            logging.debug('>>> %s' % message)
        except gevent.socket.error:
            raise AceException('Error writing data to AceEngine API port')

    def aceInit(self, gender=AceConst.SEX_MALE, age=AceConst.AGE_25_34, product_key=None, videoseekback=0, videotimeout=30):
        '''
        Spawn the reader greenlet and perform the HELLOBG / READY handshake.

        :param videotimeout: also used as the reader's idle timeout (seconds)
        :raises AceException: on HELLOTS or AUTH timeout, or a NOTREADY reply
        '''
        self._gender = gender
        self._age = age
        self._product_key = product_key
        self._seekback = videoseekback
        self._videotimeout = videotimeout
        self._started_again.clear()
        # Spawning telnet data reader with recvbuffer read timeout (allowable STATE 0 (IDLE) time)
        gevent.spawn(wrap_errors((EOFError, gevent.socket.error), self._recvData),
                     self._videotimeout).link_exception(
                         lambda x: logging.error('Error reading data from AceEngine API port'))

        self._auth = AsyncResult()
        self._write(AceMessage.request.HELLO)  # Sending HELLOBG
        try:
            params = self._auth.get(timeout=self._resulttimeout)
        except gevent.Timeout as t:
            errmsg = 'Engine response time %s exceeded. HELLOTS not resived!' % t
            raise AceException(errmsg)

        if isinstance(params, dict):
            # _auth already holds the HELLOTS dict: wait on a fresh result so the
            # READY reply (AUTH or NOTREADY) is what gets checked below.
            self._auth = AsyncResult()
            self._write(AceMessage.request.READY(params.get('key', ''), self._product_key))

        try:
            auth = self._auth.get(timeout=self._resulttimeout)
        except gevent.Timeout as t:
            errmsg = 'Engine response time %s exceeded. AUTH not resived!' % t
            raise AceException(errmsg)
        if auth == 'NOTREADY':  # Got NOTREADY instead of AUTH user_auth_level
            errmsg = 'NOTREADY recived from AceEngine! Wrong acekey?'
            raise AceException(errmsg)

        # A non-dict HELLO reply carries no version information
        version_code = int(params.get('version_code', 0)) if isinstance(params, dict) else 0
        if version_code >= 3003600:  # Display download_stopped message
            params_dict = {'use_stop_notifications': '1'}
            self._write(AceMessage.request.SETOPTIONS(params_dict))

    def START(self, command, paramsdict, acestreamtype):
        '''
        Start video method. Get url for play from AceEngine.

        Note: adds a 'stream_type' entry to *paramsdict* in place.
        :raises AceException: if no URL arrives within the video timeout
        '''
        paramsdict['stream_type'] = ' '.join(['{}={}'.format(k, v) for k, v in acestreamtype.items()])
        self._url = AsyncResult()
        self._write(AceMessage.request.START(command.upper(), paramsdict))
        try:
            return self._url.get(timeout=self._videotimeout)
        except gevent.Timeout as t:
            errmsg = 'START URL not received! Engine response time %s exceeded' % t
            raise AceException(errmsg)

    def STOP(self):
        '''
        Stop video method
        '''
        # NOTE(review): self._state is an AsyncResult object; unless it defines a
        # falsy truth value this check is always true — confirm the intended test.
        if self._state:
            self._write(AceMessage.request.STOP)

    def LOADASYNC(self, command, params, sessionid='0'):
        '''
        Send LOADASYNC and return the decoded LOADRESP payload.

        :raises AceException: if LOADRESP does not arrive in time
        '''
        self._loadasync = AsyncResult()
        self._write(AceMessage.request.LOADASYNC(command.upper(), sessionid, params))
        try:
            return self._loadasync.get(timeout=self._resulttimeout)  # Get _contentinfo json
        except gevent.Timeout as t:
            errmsg = 'Engine response %s time exceeded. LOADARESP not resived!' % t
            raise AceException(errmsg)

    def GETCONTENTINFO(self, command, value, sessionid='0'):
        ''' Request content info for *value* via LOADASYNC. '''
        paramsdict = {command: value, 'developer_id': '0', 'affiliate_id': '0', 'zone_id': '0'}
        return self.LOADASYNC(command, paramsdict, sessionid)

    def GETCID(self, command, value):
        '''
        Return the content id without its leading '##' ('' if none arrived).

        :raises AceException: on LOADASYNC error status or CID timeout
        '''
        contentinfo = self.GETCONTENTINFO(command, value)
        if contentinfo['status'] not in (1, 2):
            errmsg = 'LOADASYNC returned error with message: %s' % contentinfo['message']
            raise AceException(errmsg)

        paramsdict = {'checksum': contentinfo['checksum'], 'infohash': contentinfo['infohash'],
                      'developer_id': '0', 'affiliate_id': '0', 'zone_id': '0'}
        self._cid = AsyncResult()
        self._write(AceMessage.request.GETCID(paramsdict))
        try:
            cid = self._cid.get(timeout=self._resulttimeout)
        except gevent.Timeout as t:
            errmsg = 'Engine response time %s exceeded. CID not resived!' % t
            raise AceException(errmsg)
        # reset() may have set the result to None
        return (cid or '')[2:]  # ##CID

    def GETINFOHASH(self, command, value, sessionid='0', idx=0):
        '''
        Return (infohash, first file entry whose index equals *idx* or None).

        :raises AceException: on status 0 or any other error status
        '''
        contentinfo = self.GETCONTENTINFO(command, value, sessionid)
        if contentinfo['status'] in (1, 2):
            idx = int(idx)
            return contentinfo['infohash'], next((x[0] for x in contentinfo['files'] if x[1] == idx), None)
        elif contentinfo['status'] == 0:
            errmsg = 'LOADASYNC returned status 0: The transport file does not contain audio/video files'
            raise AceException(errmsg)
        else:
            errmsg = 'LOADASYNC returned error with message: %s' % contentinfo['message']
            raise AceException(errmsg)

    def _recvData(self, timeout=30):
        '''
        Data receiver method for greenlet.

        :param timeout: if a read/dispatch cycle lasts this many seconds,
            destroy() is called when the timeout fires inside the read
        '''
        while 1:
            # Destroy socket connection if AceEngine STATE 0 (IDLE) and we didn't read anything from socket until Nsec
            with gevent.Timeout(timeout, False):
                try:
                    self._recvbuffer = self._socket.read_until('\r\n', None).strip()
                except gevent.Timeout:
                    self.destroy()
                except gevent.socket.timeout:
                    pass
                else:
                    logging.debug('<<< %s' % unquote(self._recvbuffer))
                    # HELLOTS
                    if self._recvbuffer.startswith('HELLOTS'):
                        # version=engine_version version_code=version_code key=request_key http_port=http_port
                        self._auth.set(self._parse_params(self._recvbuffer.split()))
                    # NOTREADY
                    elif self._recvbuffer.startswith('NOTREADY'):
                        self._auth.set('NOTREADY')
                    # AUTH
                    elif self._recvbuffer.startswith('AUTH'):
                        self._auth.set(self._recvbuffer.split()[1])  # user_auth_level
                    # START
                    elif self._recvbuffer.startswith('START'):
                        # url [ad=1 [interruptable=1]] [stream=1] [pos=position]
                        params = self._parse_params(self._recvbuffer.split())
                        if not self._seekback or self._started_again.ready() or params.get('stream', '') != '1':
                            # If seekback is disabled, we use link in first START command.
                            # If seekback is enabled, we wait for first START command and
                            # ignore it, then do seekback in first EVENT position command
                            # AceStream sends us STOP and START again with new link.
                            # We use only second link then.
                            self._url.set(self._recvbuffer.split()[1])  # url for play
                    # LOADRESP
                    elif self._recvbuffer.startswith('LOADRESP'):
                        self._loadasync.set(json.loads(unquote(''.join(self._recvbuffer.split()[2:]))))
                    # STATE
                    elif self._recvbuffer.startswith('STATE'):
                        self._state.set(self._recvbuffer.split()[1])  # STATE state_id -> STATE_NAME
                    # STATUS
                    elif self._recvbuffer.startswith('STATUS'):
                        self._tempstatus = self._recvbuffer.split()[1]
                        fields = self._tempstatus.split(';')
                        stat = [fields[0].split(':')[1]]  # main:???? -> ????
                        # idle, loading, starting, check and err carry no numeric fields
                        if self._tempstatus.startswith('main:dl'):  # dl;
                            stat.extend(map(int, fields[1:]))
                        elif self._tempstatus.startswith('main:wait'):  # wait;time;
                            stat.extend(map(int, fields[2:]))
                        elif self._tempstatus.startswith(('main:prebuf', 'main:buf')):  # buf;progress;time;
                            stat.extend(map(int, fields[3:]))
                        try:
                            self._status.set(dict(zip(AceConst.STATUS, stat)))  # dl, wait, buf, prebuf
                        except Exception:
                            self._status.set({'status': stat[0]})  # idle, loading, starting, check
                    # CID
                    elif self._recvbuffer.startswith('##'):
                        self._cid.set(self._recvbuffer)
                    # INFO
                    elif self._recvbuffer.startswith('INFO'):
                        pass
                    # EVENT
                    elif self._recvbuffer.startswith('EVENT'):
                        self._tempevent = self._recvbuffer.split()
                        if self._seekback and not self._started_again.ready() and 'livepos' in self._tempevent:
                            params = self._parse_params(self._tempevent)
                            self._write(AceMessage.request.LIVESEEK(int(params['last']) - self._seekback))
                            self._started_again.set()
                        elif 'getuserdata' in self._tempevent:
                            self._write(AceMessage.request.USERDATA(self._gender, self._age))
                        # 'cansave', 'showurl' and 'download_stopped' events need no action
                    # PAUSE / RESUME / STOP: no action (forwarding these events is disabled)
                    elif self._recvbuffer.startswith(('PAUSE', 'RESUME', 'STOP')):
                        pass
                    # SHUTDOWN
                    elif self._recvbuffer.startswith('SHUTDOWN'):
                        self._socket.close()
                        break
class AceClient(object):
    '''
    Client for the AceStream engine telnet API.

    The constructor connects to the first reachable engine from a host list
    and spawns the `_recvData` reader greenlet. Requests write a command and
    wait on the shared ``_result`` AsyncResult via _getResult().
    '''

    def __init__(self, acehostslist, connect_timeout=5, result_timeout=10):
        '''
        :param acehostslist: iterable of (host, api_port, http_port) entries
        :param connect_timeout: Telnet connect timeout (seconds)
        :param result_timeout: seconds to wait for engine replies
        '''
        # Receive buffer
        self._recvbuffer = None
        # Ace stream socket
        self._socket = None
        # Result timeout
        self._resulttimeout = float(result_timeout)
        # Shutting down flag
        self._shuttingDown = Event()
        # Product key
        self._product_key = None
        # Current STATUS
        self._status = None
        # Current EVENT
        self._event = None
        # Current STATE
        self._state = None
        # Current AUTH
        self._gender = None
        self._age = None
        # Result (Created with AsyncResult() on call)
        self._result = AsyncResult()
        # Seekback seconds.
        self._seekback = None
        # Did we get START command again? For seekback.
        self._started_again = Event()

        self._idleSince = time.time()
        self._streamReaderConnection = None
        self._streamReaderState = Event()
        self._streamReaderQueue = gevent.queue.Queue(maxsize=1024)  # Ring buffer with max number of chunks in queue
        self._engine_version_code = None

        logger = logging.getLogger('AceClient')
        # Try to connect AceStream engine
        for engine in acehostslist:
            try:
                self._socket = Telnet(engine[0], engine[1], connect_timeout)
                AceConfig.acehost, AceConfig.aceAPIport, AceConfig.aceHTTPport = engine[0], engine[1], engine[2]
                logger.debug('Successfully connected to AceStream on %s:%d' % (engine[0], engine[1]))
                break
            except Exception:
                logger.debug('The are no alive AceStream on %s:%d' % (engine[0], engine[1]))

        # Spawning recvData greenlet
        if self._socket is not None:
            gevent.spawn(self._recvData)
        else:
            logger.error('The are no alive AceStream Engines found')

    @staticmethod
    def _parse_params(tokens):
        '''
        Return a dict built from the ``key=value`` items of *tokens*.

        Only the first '=' separates key from value, so a value that itself
        contains '=' does not raise ValueError.
        '''
        return dict(token.split('=', 1) for token in tokens if '=' in token)

    def destroy(self):
        '''
        AceClient Destructor.

        Returns immediately if shutdown already started. Otherwise sets
        ``_result`` to None (releasing any waiter), marks the client as
        shutting down and sends SHUTDOWN, ignoring write errors.
        '''
        logger = logging.getLogger('AceClient_destroy')
        if self._shuttingDown.ready():
            return  # Already in the middle of destroying
        self._result.set()
        self._shuttingDown.set()
        # Trying to disconnect
        try:
            logger.debug('Destroying AceStream client.....')
            self._write(AceMessage.request.SHUTDOWN)
        except Exception as e:
            # Best effort: the engine may already be gone.
            logger.debug('Ignoring error while sending SHUTDOWN: %s' % repr(e))

    def reset(self):
        ''' Restart the idle timer, clear flags and release any waiter on _result. '''
        self._idleSince = time.time()
        self._started_again.clear()
        self._streamReaderState.clear()
        self._result.set()

    def _write(self, message):
        '''
        Send *message* to AceEngine terminated by CRLF.

        :raises AceException: if the socket reports EOF
        '''
        try:
            logger = logging.getLogger('AceClient_write')
            logger.debug('>>> %s' % message)
            self._socket.write('%s\r\n' % message)
        except EOFError as e:
            raise AceException('Write error! %s' % repr(e))

    def aceInit(self, gender=AceConst.SEX_MALE, age=AceConst.AGE_25_34, product_key=AceConfig.acekey):
        '''
        Perform the HELLOBG / READY handshake with AceEngine.

        :raises AceException: on HELLOTS timeout, AUTH timeout or NOTREADY
        '''
        self._product_key = product_key
        self._gender = gender
        self._age = age
        # Seekback seconds
        self._seekback = AceConfig.videoseekback
        self._started_again.clear()

        self._result = AsyncResult()
        self._write(AceMessage.request.HELLO)  # Sending HELLOBG
        try:
            params = self._getResult(timeout=self._resulttimeout)
        except AceException:
            # _getResult converts gevent.Timeout into AceException
            errmsg = 'HELLOTS not resived from engine!'
            raise AceException(errmsg)
        self._engine_version_code = int(params.get('version_code', 0))

        self._result = AsyncResult()
        self._write(AceMessage.request.READY(params.get('key', ''), self._product_key))
        if not self._getResult(timeout=self._resulttimeout):  # NOTREADY sets False instead of AUTH user_auth_level
            errmsg = 'NOTREADY recived from AceEngine! Wrong acekey?'
            raise AceException(errmsg)

        if self._engine_version_code >= 3003600:  # Display download_stopped message
            params_dict = {'use_stop_notifications': '1'}
            self._write(AceMessage.request.SETOPTIONS(params_dict))

    def _getResult(self, timeout=10.0):
        '''
        Wait for ``_result`` and return its value.

        :raises AceException: if nothing is set within *timeout* seconds
        '''
        try:
            return self._result.get(timeout=timeout)
        except gevent.Timeout:
            errmsg = 'Engine response time exceeded. getResult timeout from %s:%s' % (AceConfig.acehost, AceConfig.aceAPIport)
            raise AceException(errmsg)

    def START(self, datatype, value, stream_type):
        '''
        Start video method
        Returns the url provided by AceEngine
        '''
        # _engine_version_code is None until aceInit() has run
        if stream_type == 'hls' and (self._engine_version_code or 0) >= 3010500:
            params_dict = {'output_format': stream_type,
                           'transcode_audio': AceConfig.transcode_audio,
                           'transcode_mp3': AceConfig.transcode_mp3,
                           'transcode_ac3': AceConfig.transcode_ac3,
                           'preferred_audio_language': AceConfig.preferred_audio_language}
        else:
            params_dict = {'output_format': 'http'}
        self._result = AsyncResult()
        self._write(AceMessage.request.START(datatype.upper(), value,
                                             ' '.join(['{}={}'.format(k, v) for k, v in params_dict.items()])))
        return self._getResult(timeout=float(AceConfig.videotimeout))  # Get url for play from AceEngine

    def STOP(self):
        '''
        Stop video method
        '''
        self._result = AsyncResult()
        self._write(AceMessage.request.STOP)
        self._getResult(timeout=self._resulttimeout)  # Get STATE 0(IDLE) after sending STOP to AceEngine

    def LOADASYNC(self, datatype, params):
        ''' Send LOADASYNC and return the decoded LOADRESP payload. '''
        self._result = AsyncResult()
        self._write(AceMessage.request.LOADASYNC(datatype.upper(), random.randint(1, AceConfig.maxconns * 10000), params))
        return self._getResult(timeout=self._resulttimeout)  # Get _contentinfo json from AceEngine

    def GETCONTENTINFO(self, datatype, value):
        ''' Request content info for *value* via LOADASYNC. '''
        params_dict = {datatype: value, 'developer_id': '0', 'affiliate_id': '0', 'zone_id': '0'}
        return self.LOADASYNC(datatype, params_dict)

    def GETCID(self, datatype, url):
        '''
        Return the content id without its leading '##' ('' if empty).

        :raises AceException: on LOADASYNC error status or if the CID does not
            arrive within 5 seconds (previously a raw gevent.Timeout escaped)
        '''
        contentinfo = self.GETCONTENTINFO(datatype, url)
        if contentinfo['status'] not in (1, 2):
            errmsg = 'LOADASYNC returned error with message: %s' % contentinfo['message']
            raise AceException(errmsg)

        params_dict = {'checksum': contentinfo['checksum'], 'infohash': contentinfo['infohash'],
                       'developer_id': '0', 'affiliate_id': '0', 'zone_id': '0'}
        self._result = AsyncResult()
        self._write(AceMessage.request.GETCID(params_dict))
        cid = self._getResult(timeout=5.0)
        return (cid or '')[2:]

    def GETINFOHASH(self, datatype, url, idx=0):
        '''
        Return (infohash, first file entry whose index equals *idx*).

        :raises AceException: on status 0 or any other error status
        '''
        contentinfo = self.GETCONTENTINFO(datatype, url)
        if contentinfo['status'] in (1, 2):
            return contentinfo['infohash'], [x[0] for x in contentinfo['files'] if x[1] == int(idx)][0]
        elif contentinfo['status'] == 0:
            errmsg = 'LOADASYNC returned status 0: The transport file does not contain audio/video files'
            raise AceException(errmsg)
        else:
            errmsg = 'LOADASYNC returned error with message: %s' % contentinfo['message']
            raise AceException(errmsg)

    def startStreamReader(self, url, cid, counter, req_headers=None):
        '''
        Relay the video stream at *url* to the clients registered for *cid*.

        HLS URLs (.m3u8) are remuxed to MPEG-TS by an ffmpeg subprocess; other
        URLs are read directly over HTTP. Stops on end of stream, read timeout,
        or when no clients remain; always closes the reader, kills ffmpeg and
        removes the clients of *cid* from *counter*.
        '''
        logger = logging.getLogger('StreamReader')
        logger.debug('Open video stream: %s' % url)
        transcoder = None

        logger.debug('Get headers from client: %s' % req_headers)

        try:
            if url.endswith('.m3u8'):
                logger.warning('HLS stream detected. Ffmpeg transcoding started')
                popen_params = {'bufsize': requests.models.CONTENT_CHUNK_SIZE,
                                'stdout': gevent.subprocess.PIPE,
                                'stderr': None,
                                'shell': False}

                if AceConfig.osplatform == 'Windows':
                    ffmpeg_bin = 'ffmpeg.exe'
                    CREATE_NO_WINDOW = 0x08000000
                    CREATE_NEW_PROCESS_GROUP = 0x00000200
                    DETACHED_PROCESS = 0x00000008
                    popen_params.update(creationflags=CREATE_NO_WINDOW | DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP)
                else:
                    ffmpeg_bin = 'ffmpeg'

                # Explicit argv list: the URL stays a single argument even if it
                # contains whitespace (the old str.split() would have broken it up).
                ffmpeg_cmd = [ffmpeg_bin, '-hwaccel', 'auto', '-hide_banner', '-loglevel', 'fatal',
                              '-re', '-i', url, '-c', 'copy', '-f', 'mpegts', '-']
                transcoder = gevent.subprocess.Popen(ffmpeg_cmd, **popen_params)
                out = transcoder.stdout
            else:
                self._streamReaderConnection = requests.get(url, headers=req_headers, stream=True,
                                                            timeout=(5, AceConfig.videotimeout))
                self._streamReaderConnection.raise_for_status()  # raise an exception for error codes (4xx or 5xx)
                out = self._streamReaderConnection.raw

            self._streamReaderState.set()
            self._write(AceMessage.request.EVENT('play'))

            while 1:
                gevent.sleep()
                clients = counter.getClients(cid)
                if not clients:
                    logger.debug('All clients disconnected - broadcast stoped')
                    break
                try:
                    chunk = out.read(requests.models.CONTENT_CHUNK_SIZE)
                except requests.packages.urllib3.exceptions.ReadTimeoutError:
                    logger.warning('No data received from AceEngine for %ssec - broadcast stoped' % AceConfig.videotimeout)
                    break
                except Exception:
                    break
                if not chunk:
                    # End of stream: without this the loop would spin sending empty chunks
                    logger.debug('Video stream ended - broadcast stoped')
                    break
                try:
                    self._streamReaderQueue.put_nowait(chunk)
                except gevent.queue.Full:
                    # Ring buffer: drop the oldest chunk to make room
                    self._streamReaderQueue.get_nowait()
                    self._streamReaderQueue.put_nowait(chunk)
                for c in clients:
                    try:
                        c.queue.put(chunk, timeout=5)
                    except gevent.queue.Full:
                        if len(clients) > 1:
                            logger.warning('Client %s does not read data from buffer until 5sec - disconnect it' % c.handler.clientip)
                            c.destroy()
                    except gevent.GreenletExit:
                        pass

        except requests.exceptions.HTTPError as err:
            logger.error('An http error occurred while connecting to aceengine: %s' % repr(err))
        except requests.exceptions.RequestException:
            logger.error('There was an ambiguous exception that occurred while handling request')
        except Exception as err:
            logger.error('Unexpected error in streamreader %s' % repr(err))
        finally:
            self.closeStreamReader()
            if transcoder is not None:
                try:
                    transcoder.kill()
                    transcoder.wait()  # reap the child so it does not linger as a zombie
                    logger.warning('Ffmpeg transcoding stoped')
                except Exception:
                    pass
            counter.deleteAll(cid)

    def closeStreamReader(self):
        ''' Clear the stream-reader flag, close the HTTP connection and empty the ring buffer. '''
        logger = logging.getLogger('StreamReader')
        self._streamReaderState.clear()
        # `is not None`: a requests.Response is falsy for 4xx/5xx codes, so a
        # plain truth test would skip closing exactly the failed responses.
        if self._streamReaderConnection is not None:
            logger.debug('Close video stream: %s' % self._streamReaderConnection.url)
            self._streamReaderConnection.close()
        self._streamReaderQueue.queue.clear()

    def _recvData(self):
        '''
        Data receiver method for greenlet.

        Reads CRLF-terminated lines from the engine and dispatches them by
        prefix until SHUTDOWN is received, the greenlet is killed, or a read
        error occurs. Kill and read error both set the shutting-down flag.
        '''
        logger = logging.getLogger('AceClient_recvdata')

        while 1:
            gevent.sleep()
            try:
                self._recvbuffer = self._socket.read_until('\r\n').strip()
                logger.debug('<<< %s' % requests.compat.unquote(self._recvbuffer))
            except gevent.GreenletExit:
                self._shuttingDown.set()
                return
            except Exception:
                # If something happened during read, abandon reader.
                logger.error('Exception at socket read. AceClient destroyed')
                self._shuttingDown.set()
                return

            # HELLOTS
            if self._recvbuffer.startswith('HELLOTS'):
                # version=engine_version version_code=version_code key=request_key http_port=http_port
                self._result.set(self._parse_params(self._recvbuffer.split()))
            # NOTREADY
            elif self._recvbuffer.startswith('NOTREADY'):
                self._result.set(False)
            # AUTH
            elif self._recvbuffer.startswith('AUTH'):
                self._result.set(self._recvbuffer.split()[1])  # user_auth_level
            # START
            elif self._recvbuffer.startswith('START'):
                # url [ad=1 [interruptable=1]] [stream=1] [pos=position]
                params = self._parse_params(self._recvbuffer.split())
                if not self._seekback or self._started_again.ready() or params.get('stream', '') != '1':
                    # If seekback is disabled, we use link in first START command.
                    # If seekback is enabled, we wait for first START command and
                    # ignore it, then do seekback in first EVENT position command
                    # AceStream sends us STOP and START again with new link.
                    # We use only second link then.
                    self._result.set(self._recvbuffer.split()[1])  # url for play
            # LOADRESP
            elif self._recvbuffer.startswith('LOADRESP'):
                self._result.set(requests.compat.json.loads(requests.compat.unquote(' '.join(self._recvbuffer.split()[2:]))))
            # STATE
            elif self._recvbuffer.startswith('STATE'):
                self._state = self._recvbuffer.split()[1]  # state_id
                if self._state == '0':  # 0 (IDLE): acknowledge and release STOP()
                    self._write(AceMessage.request.EVENT('stop'))
                    self._result.set()
                # 1 (PREBUFFERING), 2 (DOWNLOADING), 3 (BUFFERING), 4 (COMPLETED),
                # 5 (CHECKING), 6 (ERROR): no action
            # STATUS
            elif self._recvbuffer.startswith('STATUS'):
                pass
            # CID
            elif self._recvbuffer.startswith('##'):
                self._result.set(self._recvbuffer)
            # INFO
            elif self._recvbuffer.startswith('INFO'):
                pass
            # EVENT
            elif self._recvbuffer.startswith('EVENT'):
                self._event = self._recvbuffer.split()
                if 'livepos' in self._event:
                    if self._seekback and not self._started_again.ready():  # if seekback
                        params = self._parse_params(self._event)
                        self._write(AceMessage.request.LIVESEEK(int(params['last']) - self._seekback))
                        self._started_again.set()
                elif 'getuserdata' in self._event:
                    self._write(AceMessage.request.USERDATA(self._gender, self._age))
                # 'cansave', 'showurl' and 'download_stopped' events need no action
            # PAUSE
            elif self._recvbuffer.startswith('PAUSE'):
                self._write(AceMessage.request.EVENT('pause'))
            # RESUME
            elif self._recvbuffer.startswith('RESUME'):
                self._write(AceMessage.request.EVENT('play'))
            # STOP
            elif self._recvbuffer.startswith('STOP'):
                pass
            # SHUTDOWN
            elif self._recvbuffer.startswith('SHUTDOWN'):
                self._socket.close()
                logger.debug('AceClient destroyed')
                return
class RestTransactionMixin(ATransaction):
    """Wrap ATransaction class to create RestTransaction class.

    On top of the base transaction this adds:

    * a ping watchdog (:meth:`ping_timeout_thread`) that rolls the
      transaction back when none of ``ping``, ``done`` or ``fail`` is
      signalled within ``2 * ping_timeout`` seconds;
    * HTTP ``PUT`` callbacks to ``callback_url``: one carrying the result once
      the transaction is ready to commit, and one after the commit is done.

    Every step is also reported through ``debug_SSE.event``.
    """

    def __init__(self, _id, callback_url: str, ping_timeout: Seconds,
                 local_timeout: Seconds):
        """
        :param _id: transaction_id (local)
        :param callback_url: url to send results of commit
        :param ping_timeout: timeout (seconds) of checking service status
        :param local_timeout: timeout (seconds) of local transaction
        """
        ATransaction.__init__(self, _id)
        self.callback_url = callback_url
        self.ping_timeout = ping_timeout
        self.local_timeout = local_timeout
        # Key included in every callback payload sent to callback_url.
        # NOTE(review): derived from the clock and `random.randint`, so it is
        # predictable; if it is meant to authenticate callbacks, generate it
        # with the `secrets` module instead.
        self.key = sha256(
            bytes(str(self.id) + str(int(time.time() * 10**6) ^ randint(0, 2**20)),
                  encoding="utf-8")).hexdigest()
        debug_SSE.event({
            "event": "init",
            "t": datetime.now(),
            "data": {
                "callback_url": self.callback_url,
                "local_timeout": self.local_timeout * 1000,
                "ping_timeout": self.ping_timeout * 1000,
                "key": self.key,
                "_id": self.id
            }
        })  # DEBUG init
        self._ping = Event()  # set by ping(), cleared by the watchdog loop
        self.ping_timeout_thread_obj: Greenlet = None

    @g_async
    def _spawn(self):
        """Start the ping watchdog and the ready-to-commit notifier."""
        self.ping_timeout_thread_obj = self.ping_timeout_thread(
        )  # THREAD:1, loop
        self.ready_commit_thread_obj = self.ready_commit_handler()  # THREAD:1

    @g_async
    def ping_timeout_thread(self):
        """Roll the transaction back if pings stop arriving.

        Loops until ``done`` or ``fail`` is set. Each round waits up to
        ``2 * ping_timeout`` seconds for ``ping``, ``done`` or ``fail``; if
        none is signalled, the base-class rollback runs and the loop ends.
        """
        while not (self.done.ready() or self.fail.ready()):
            debug_SSE.event({
                "event": "wait_ping",
                "t": datetime.now(),
                "data": None
            })  # DEBUG wait_ping
            w = wait((self._ping, self.done, self.fail),
                     count=1,
                     timeout=self.ping_timeout * 2)  # BLOCK, ping_timeout * 2
            if not w:  # nothing was signalled before the timeout
                debug_SSE.event({
                    "event": "fail",
                    "t": datetime.now(),
                    "data": "ping timeout"
                })  # DEBUG ping timeout
                super().do_rollback()
                break
            if self._ping.ready():
                debug_SSE.event({
                    "event": "ping",
                    "t": datetime.now(),
                    "data": None
                })  # DEBUG ping
                self._ping.clear()  # EMIT(-ping)
            sleep()  # presumably gevent.sleep(): yield before the next round

    def ping(self) -> bool:
        """Ping request handler.

        :return: True if the ping was recorded, False if the transaction is
            already done or failed.
        """
        if not (self.fail.ready() or self.done.ready()):
            self._ping.set()  # EMIT(ping)
            return True
        return False

    @g_async
    def ready_commit_handler(self):
        """PUT the transaction result to ``callback_url`` once ready to commit.

        Waits up to ``local_timeout`` seconds for ``ready_commit`` or ``fail``.
        Unless ``fail`` is set, sends ``{"key": ..., "response": {"data": result}}``.
        """
        wait((self.ready_commit, self.fail),
             count=1,
             timeout=self.local_timeout)  # BLOCK, local_timeout
        # NOTE(review): if the wait times out with neither event set, this still
        # proceeds and may block on `self.result.get()` when no result was set;
        # confirm whether a rollback is intended on local timeout.
        if not self.fail.ready():
            debug_SSE.event({
                "event": "ready_commit",
                "t": datetime.now(),
                "data": None
            })  # DEBUG ready_commit
            data = {"key": self.key, "response": {"data": self.result.get()}}
            requests.put(self.callback_url,
                         headers={"Connection": "close"},
                         json=data,
                         timeout=5)

    @g_async
    def do_commit(self):
        """Commit and notify ``callback_url`` with ``{"key": ..., "done": True}``.

        Does nothing unless ``fail`` is unset and both ``ready_commit`` and
        ``result`` are ready.
        """
        if not self.fail.ready() and self.ready_commit.ready(
        ) and self.result.ready():
            super().do_commit()
            debug_SSE.event({
                "event": "commit",
                "t": datetime.now(),
                "data": None
            })  # DEBUG commit
            data = {"key": self.key, "done": True}
            # Bounded like the ready_commit callback: without a timeout an
            # unresponsive callback server would block this greenlet forever.
            requests.put(self.callback_url,
                         headers={"Connection": "close"},
                         json=data,
                         timeout=5)
            debug_SSE.event({
                "event": "done",
                "t": datetime.now(),
                "data": None
            })  # DEBUG done

    @g_async
    def do_rollback(self):
        """Run the base-class rollback, then emit a ``rollback`` debug event."""
        super().do_rollback()
        debug_SSE.event({
            "event": "rollback",
            "t": datetime.now(),
            "data": None
        })  # DEBUG rollback
class Websocket(object):
    """Bridge JSON messages between a websocket and a task's channels.

    Frames received from the websocket are JSON-decoded and sent to
    ``task.output``. Unless ``readonly`` is set, messages watched on
    ``task.input`` are JSON-encoded and sent to the websocket.
    """

    def __init__(self, websocket, readonly=False, remote=None):
        self._ws = websocket
        self._closed = Event()
        self._readonly = readonly
        self._remote = remote
        # Greenlets joined by wait(); run() puts only the receive loop here.
        self._tasks = None
        # Send-loop greenlet, kept so run() can kill it instead of leaking it.
        self._sender = None

    def __repr__(self):
        return "%s @ %s" % (self.__class__.__name__, self._remote or "%x" %
                            (id(self), ))

    def run(self, task):
        """Pump messages for ``task`` until the receive loop ends, then stop.

        Blocks until the receive loop finishes. On exit the websocket is closed
        and the send loop (if any) is killed.
        """
        assert self._tasks is None
        try:
            self._tasks = [
                gevent.spawn(self._recvloop, task),
            ]
            if not self._readonly:
                # Deliberately not joined by wait(): the send loop only notices
                # closing when the next message arrives on task.input, so
                # joining it could block run() indefinitely.
                self._sender = gevent.spawn(self._sendloop, task)
            self.wait()
        finally:
            try:
                self.stop()
            finally:
                if self._sender is not None:
                    self._sender.kill()

    @property
    def closed(self):
        return self._closed.ready()

    def wait(self):
        """Block until every greenlet in ``self._tasks`` has finished."""
        assert self._tasks is not None
        gevent.joinall(self._tasks)

    def _recvloop(self, task):
        """Decode websocket frames as JSON and send them to ``task.output``.

        Frames that are not valid JSON are logged and skipped. The loop ends on
        a WebSocketError, on a ``None``/``StopIteration`` frame, or once closed,
        and then stops the bridge.
        """
        from geventwebsocket.exceptions import WebSocketError
        while not self.closed:
            try:
                data = self._ws.receive()
            except WebSocketError:
                LOG.exception("%r recvloop", self)
                break
            if data in (StopIteration, None):
                break
            try:
                msg = json.loads(data)
            except ValueError:
                LOG.exception("%r recv decode error for %r", self, data)
                continue
            task.output.send(msg)
        LOG.debug("%r recvloop finished", self)
        self.stop()

    def _sendloop(self, task):
        """JSON-encode messages from ``task.input`` and send them to the websocket.

        Send errors are logged and the message is dropped. The loop ends on
        ``StopIteration`` or when a message arrives after closing, and then
        stops the bridge.
        """
        for msg in task.input.watch():
            if msg is StopIteration or self.closed:
                break
            try:
                self._ws.send(json.dumps(msg))
            except Exception:
                LOG.exception("%r send error for %r", self, msg)
                continue
        LOG.debug("%r sendloop finished", self)
        self.stop()

    def stop(self):
        """Close the websocket once; later calls are no-ops."""
        if self.closed:
            return
        # Mark closed before closing the socket: ws.close() may yield to the
        # other loop greenlet, which would otherwise also see closed == False
        # and close the socket a second time.
        self._closed.set()
        self._ws.close()
        LOG.debug("%r stopping", self)

    @classmethod
    def communicate(cls, websocket, bridge):
        """Wrap ``websocket`` and run it against ``bridge`` until it closes."""
        ws2b = cls(websocket)
        ws2b.run(bridge)
class AServer(metaclass=ABCMeta):
    """Abstract TCP server that dispatches decoded requests to named handlers.

    Subclasses implement ``__init__`` and ``serve_forever``. Handlers are
    registered with the :meth:`method` decorator, and ``_handler`` serves the
    requests of one client socket.
    """
    _logger = logging.getLogger("AServer")

    @abstractmethod
    def __init__(self, address: Tuple[str, int], max_connections=100):
        self.max_connections = max_connections
        self.methods_map = {}  # type: Dict[str, HandlerType]
        self._stop = Event()

    @abstractmethod
    def serve_forever(self):
        pass

    def stop(self):
        """Make connection handlers exit before reading their next request."""
        self._stop.set()

    def method(self, f: "HandlerType"):
        """
        Decorator. Create TCP method called same as handler function

        :param f: handler, registered under ``f.__name__`` and called with the
            decoded request payload
        :return: ``f`` unchanged
        :raises NameError: if a handler with the same name is already registered
        """
        if f.__name__ in self.methods_map:
            raise NameError(f.__name__)
        self.methods_map[f.__name__] = f
        return f

    # noinspection PyUnresolvedReferences
    def _handler(self, socket_obj: socket.SocketType, address: Tuple[str,
                                                                     int]):
        """Serve requests from one client socket.

        Returns when the peer sends nothing, receiving fails, a request cannot
        be decoded, the method is unknown, or :meth:`stop` was called. Handler
        errors are answered with an error response and the connection stays
        open.
        """
        while not self._stop.ready():
            try:
                raw = receive(socket_obj)  # BLOCK
            except OSError as e:
                self.log("?", e, "Call failed", "WARNING")
                return
            if not raw:
                return

            try:
                method, payload = Request.decode(raw).values
            except Tcp500 as e:
                # Best effort: recover the method name (text before the first
                # newline, trailing byte dropped) purely for the log line.
                try:
                    method, _ = str(raw[:-1], encoding="utf-8").split("\n", 1)
                except (TypeError, ValueError):
                    method = str(None)
                self.log(method, "500", str(e.args), "WARNING")
                socket_obj.sendall(Response(e.status, e).encode())
                return

            if method not in self.methods_map:
                e = Tcp404("method not found")
                self.log(method, "404", "Method not found", "WARNING")
                socket_obj.sendall(Response(e.status, e).encode())
                return

            try:
                resp = self.methods_map[method](payload)
                self.log(method, "200", "", "INFO")
            except ATcpException as e:
                self.log(method, str(e.status), str(e.args), "WARNING")
                resp = Response(e.status, e.text)
            except Exception as e:
                self.log(method, "500", str(e.args), "WARNING")
                resp = Response(500, "Fail")
            else:
                # Handlers may return a plain tuple of Response arguments.
                if type(resp) is tuple:
                    resp = Response(*resp)
            socket_obj.sendall(resp.encode())

    @property
    def logger(self):
        return self._logger

    def log(self, method: str, status: Union[str, Exception], msg: str,
            level: str):
        """Log ``method | status | msg`` at ``level`` (a level name, e.g. "WARNING")."""
        # `!s` converts before padding: an Exception status (passed on the
        # OSError path of _handler) rejects a non-empty format spec, so a bare
        # "{status:5s}" would raise TypeError.
        s = f"{method!s:20s} | {status!s:5s} | {msg}"
        self.logger.log(logging.getLevelName(level), s)
class UserAddressManager:
    """ Matrix user <-> eth address mapping and user / address reachability helper.

    In Raiden the smallest unit of addressability is a node with an associated Ethereum address.
    In Matrix it's a user. Matrix users are (at the moment) bound to a specific homeserver.
    Since we want to provide resiliency against unavailable homeservers a single Raiden node with
    a single Ethereum address can be in control over multiple Matrix users on multiple
    homeservers.
    Therefore we need to perform a many-to-one mapping of Matrix users to Ethereum addresses.
    Each Matrix user has a presence state (ONLINE, OFFLINE).
    One of the preconditions of running a Raiden node is that there can always only be one node
    online for a particular address at a time.
    That means we can synthesize the reachability of an address from the user presence states.

    This helper internally tracks both the user presence and address reachability for addresses
    that have been marked as being 'interesting' (by calling the `.add_address()` method).
    Additionally it provides the option of passing callbacks that will be notified when
    presence / reachability change.
    """

    def __init__(
        self,
        client: GMatrixClient,
        get_user_callable: Callable[[Union[User, str]], User],
        address_reachability_changed_callback: Callable[
            [Address, AddressReachability], None],
        user_presence_changed_callback: Optional[Callable[[User, UserPresence],
                                                          None]] = None,
        _log_context: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._client = client
        self._get_user = get_user_callable
        self._address_reachability_changed_callback = address_reachability_changed_callback
        self._user_presence_changed_callback = user_presence_changed_callback
        self._stop_event = Event()

        self._reset_state()

        self._log_context = _log_context
        self._log = None
        self._listener_id: Optional[UUID] = None

    def start(self) -> None:
        """ Start listening for presence updates.

        Should be called before ``.login()`` is called on the underlying client. """
        assert self._listener_id is None, "UserAddressManager.start() called twice"
        self._stop_event.clear()
        self._listener_id = self._client.add_presence_listener(
            self._presence_listener)

    def stop(self) -> None:
        """ Stop listening on presence updates. """
        assert self._listener_id is not None, "UserAddressManager.stop() called before start"
        self._stop_event.set()
        self._client.remove_presence_listener(self._listener_id)
        self._listener_id = None
        self._log = None
        self._reset_state()

    @property
    def known_addresses(self) -> KeysView[Address]:
        """ Return all addresses we keep track of """
        return self._address_to_userids.keys()

    def is_address_known(self, address: Address) -> bool:
        """ Is the given ``address`` reachability being monitored? """
        return address in self._address_to_userids

    def add_address(self, address: Address):
        """ Add ``address`` to the known addresses that are being observed for reachability. """
        # Since _address_to_userids is a defaultdict accessing the key creates the entry
        _ = self._address_to_userids[address]

    def add_userid_for_address(self, address: Address, user_id: str):
        """ Add a ``user_id`` for the given ``address``.

        Implicitly adds the address if it was unknown before.
        """
        self._address_to_userids[address].add(user_id)

    def add_userids_for_address(self, address: Address,
                                user_ids: Iterable[str]):
        """ Add multiple ``user_ids`` for the given ``address``.

        Implicitly adds any addresses if they were unknown before.
        """
        self._address_to_userids[address].update(user_ids)

    def get_userids_for_address(self, address: Address) -> Set[str]:
        """ Return all known user ids for the given ``address``. """
        if not self.is_address_known(address):
            return set()
        return self._address_to_userids[address]

    def get_userid_presence(self, user_id: str) -> UserPresence:
        """ Return the current presence state of ``user_id``. """
        return self._userid_to_presence.get(user_id, UserPresence.UNKNOWN)

    def get_address_reachability(self,
                                 address: Address) -> AddressReachability:
        """ Return the current reachability state for ``address``. """
        return self._address_to_reachability.get(
            address, AddressReachability.UNKNOWN)

    def force_user_presence(self, user: User, presence: UserPresence):
        """ Forcibly set the ``user`` presence to ``presence``.

        This method is only provided to cover an edge case in our use of the Matrix protocol and
        should **not** generally be used.
        """
        self._userid_to_presence[user.user_id] = presence

    def populate_userids_for_address(self,
                                     address: Address,
                                     force: bool = False):
        """ Populate known user ids for the given ``address`` from the server directory.

        If ``force`` is ``True`` perform the directory search even if there
        already are known users.
        """
        if force or not self.get_userids_for_address(address):
            self.add_userids_for_address(
                address,
                (user.user_id
                 for user in self._client.search_user_directory(
                     to_normalized_address(address))
                 if self._validate_userid_signature(user)),
            )

    def refresh_address_presence(self, address: Address):
        """
        Update synthesized address presence state from cached user presence states.

        Triggers callback (if any) in case the state has changed.

        This method is only provided to cover an edge case in our use of the Matrix protocol and
        should **not** generally be used.
        """
        composite_presence = {
            self._fetch_user_presence(uid)
            for uid in self._address_to_userids[address]
        }

        # Iterate over UserPresence in definition order (most to least online) and pick
        # first matching state
        new_presence = UserPresence.UNKNOWN
        for presence in UserPresence.__members__.values():
            if presence in composite_presence:
                new_presence = presence
                break

        new_address_reachability = USER_PRESENCE_TO_ADDRESS_REACHABILITY[
            new_presence]

        prev_address_reachability = self.get_address_reachability(address)
        if new_address_reachability == prev_address_reachability:
            # Cached address reachability matches new state, do nothing
            return
        self.log.debug(
            "Changing address reachability state",
            address=to_checksum_address(address),
            prev_state=prev_address_reachability,
            state=new_address_reachability,
        )

        self._address_to_reachability[address] = new_address_reachability
        self._address_reachability_changed_callback(address,
                                                    new_address_reachability)

    def _presence_listener(self, event: Dict[str, Any]):
        """
        Update cached user presence state from Matrix presence events.

        Due to the possibility of nodes using accounts on multiple homeservers a composite
        address state is synthesised from the cached individual user presence states.
        """
        if self._stop_event.ready():
            return
        user_id = event["sender"]
        if event["type"] != "m.presence" or user_id == self._user_id:
            return

        user = self._get_user(user_id)
        user.displayname = event["content"].get(
            "displayname") or user.displayname
        address = self._validate_userid_signature(user)
        if not address:
            # Malformed address - skip
            return

        # not a user we've whitelisted, skip
        if not self.is_address_known(address):
            return
        self.add_userid_for_address(address, user_id)

        new_state = UserPresence(event["content"]["presence"])
        if new_state == self.get_userid_presence(user_id):
            # Cached presence state matches, no action required
            return

        self.log.debug(
            "Changing user presence state",
            user_id=user_id,
            prev_state=self._userid_to_presence.get(user_id),
            state=new_state,
        )
        self._userid_to_presence[user_id] = new_state
        self.refresh_address_presence(address)

        if self._user_presence_changed_callback:
            self._user_presence_changed_callback(user, new_state)

    def log_status_message(self):
        """ Log every tracked address with its user ids and their presence.

        Repeats every 30 seconds until ``stop()`` sets the stop event.
        """
        while not self._stop_event.ready():
            addresses_uids_presence = {
                to_checksum_address(address): {
                    user_id: self.get_userid_presence(user_id).value
                    for user_id in self.get_userids_for_address(address)
                }
                for address in self.known_addresses
            }

            # Use the bound logger like the rest of this class so the node and
            # `_log_context` fields are attached to the status line as well.
            self.log.debug(
                "Matrix address manager status",
                addresses_uids_and_presence=addresses_uids_presence,
                current_user=self._user_id,
            )
            self._stop_event.wait(30)

    def _reset_state(self):
        self._address_to_userids: Dict[Address,
                                       Set[str]] = defaultdict(set)
        self._address_to_reachability: Dict[Address,
                                            AddressReachability] = dict()
        self._userid_to_presence: Dict[str, UserPresence] = dict()

    @property
    def _user_id(self) -> str:
        user_id = getattr(self._client, "user_id", None)
        assert user_id, f"{self.__class__.__name__}._user_id accessed before client login"
        return user_id

    def _fetch_user_presence(self, user_id: str) -> UserPresence:
        """ Return the cached presence of ``user_id``, querying the server on a cache miss.

        A ``MatrixRequestError`` is cached as ``UserPresence.UNKNOWN``.
        """
        if user_id not in self._userid_to_presence:
            try:
                presence = UserPresence(
                    self._client.get_user_presence(user_id))
            except MatrixRequestError:
                presence = UserPresence.UNKNOWN
            self._userid_to_presence[user_id] = presence
        return self._userid_to_presence[user_id]

    @staticmethod
    def _validate_userid_signature(user: User) -> Optional[Address]:
        return validate_userid_signature(user)

    @property
    def log(self) -> BoundLoggerLazyProxy:
        """ Logger bound to the current user/node once the client has a user id.

        Falls back to the module logger (without caching) until then.
        """
        if not self._log:
            # Check truthiness, not mere presence, to match `_user_id`: that
            # property asserts on a falsy id, so a client whose `user_id`
            # attribute exists but is still empty would otherwise raise an
            # AssertionError from every log call.
            if not getattr(self._client, "user_id", None):
                return log
            self._log = log.bind(
                **{
                    "current_user":
                    self._user_id,
                    "node":
                    to_checksum_address(self._user_id.split(":", 1)[0][1:]),
                    **(self._log_context or {}),
                })
        return self._log
class RaidenService(Runnable): """ A Raiden node. """ def __init__( self, chain: BlockChainService, query_start_block: BlockNumber, default_registry: TokenNetworkRegistry, default_secret_registry: SecretRegistry, default_service_registry: Optional[ServiceRegistry], default_one_to_n_address: Optional[Address], transport, raiden_event_handler, message_handler, config, discovery=None, user_deposit=None, ): super().__init__() self.tokennetworkids_to_connectionmanagers: ConnectionManagerDict = dict( ) self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict( dict) self.chain: BlockChainService = chain self.default_registry = default_registry self.query_start_block = query_start_block self.default_one_to_n_address = default_one_to_n_address self.default_secret_registry = default_secret_registry self.default_service_registry = default_service_registry self.config = config self.signer: Signer = LocalSigner(self.chain.client.privkey) self.address = self.signer.address self.discovery = discovery self.transport = transport self.user_deposit = user_deposit self.blockchain_events = BlockchainEvents() self.alarm = AlarmTask(chain) self.raiden_event_handler = raiden_event_handler self.message_handler = message_handler self.stop_event = Event() self.stop_event.set() # inits as stopped self.greenlets: List[Greenlet] = list() self.snapshot_group = 0 self.contract_manager = ContractManager(config["contracts_path"]) self.database_path = config["database_path"] self.wal = None if self.database_path != ":memory:": database_dir = os.path.dirname(config["database_path"]) os.makedirs(database_dir, exist_ok=True) self.database_dir = database_dir # Two raiden processes must not write to the same database. Even # though it's possible the database itself would not be corrupt, # the node's state could. If a database was shared among multiple # nodes, the database WAL would be the union of multiple node's # WAL. 
During a restart a single node can't distinguish its state # changes from the others, and it would apply it all, meaning that # a node would execute the actions of itself and the others. # # Additionally the database snapshots would be corrupt, because it # would not represent the effects of applying all the state changes # in order. lock_file = os.path.join(self.database_dir, ".lock") self.db_lock = filelock.FileLock(lock_file) else: self.database_path = ":memory:" self.database_dir = None self.serialization_file = None self.db_lock = None self.event_poll_lock = gevent.lock.Semaphore() self.gas_reserve_lock = gevent.lock.Semaphore() self.payment_identifier_lock = gevent.lock.Semaphore() # Flag used to skip the processing of all Raiden events during the # startup. # # Rationale: At the startup, the latest snapshot is restored and all # state changes which are not 'part' of it are applied. The criteria to # re-apply the state changes is their 'absence' in the snapshot, /not/ # their completeness. Because these state changes are re-executed # in-order and some of their side-effects will already have been # completed, the events should be delayed until the state is # synchronized (e.g. an open channel state change, which has already # been mined). # # Incomplete events, i.e. the ones which don't have their side-effects # applied, will be executed once the blockchain state is synchronized # because of the node's queues. self.ready_to_process_events = False def start(self): """ Start the node synchronously. Raises directly if anything went wrong on startup """ assert self.stop_event.ready(), f"Node already started. node:{self!r}" self.stop_event.clear() self.greenlets = list() self.ready_to_process_events = False # set to False because of restarts if self.database_dir is not None: self.db_lock.acquire(timeout=0) assert self.db_lock.is_locked, f"Database not locked. 
node:{self!r}" # start the registration early to speed up the start if self.config["transport_type"] == "udp": endpoint_registration_greenlet = gevent.spawn( self.discovery.register, self.address, self.config["transport"]["udp"]["external_ip"], self.config["transport"]["udp"]["external_port"], ) self.maybe_upgrade_db() storage = sqlite.SerializedSQLiteStorage( database_path=self.database_path, serializer=serialize.JSONSerializer()) storage.update_version() storage.log_run() self.wal = wal.restore_to_state_change( transition_function=node.state_transition, storage=storage, state_change_identifier="latest", ) if self.wal.state_manager.current_state is None: log.debug("No recoverable state available, creating inital state.", node=pex(self.address)) # On first run Raiden needs to fetch all events for the payment # network, to reconstruct all token network graphs and find opened # channels last_log_block_number = self.query_start_block last_log_block_hash = self.chain.client.blockhash_from_blocknumber( last_log_block_number) state_change = ActionInitChain( pseudo_random_generator=random.Random(), block_number=last_log_block_number, block_hash=last_log_block_hash, our_address=self.chain.node_address, chain_id=self.chain.network_id, ) self.handle_and_track_state_change(state_change) payment_network = PaymentNetworkState( self.default_registry.address, [], # empty list of token network states as it's the node's startup ) state_change = ContractReceiveNewPaymentNetwork( transaction_hash=constants.EMPTY_HASH, payment_network=payment_network, block_number=last_log_block_number, block_hash=last_log_block_hash, ) self.handle_and_track_state_change(state_change) else: # The `Block` state change is dispatched only after all the events # for that given block have been processed, filters can be safely # installed starting from this position without losing events. 
last_log_block_number = views.block_number( self.wal.state_manager.current_state) log.debug( "Restored state from WAL", last_restored_block=last_log_block_number, node=pex(self.address), ) known_networks = views.get_payment_network_identifiers( views.state_from_raiden(self)) if known_networks and self.default_registry.address not in known_networks: configured_registry = pex(self.default_registry.address) known_registries = lpex(known_networks) raise RuntimeError( f"Token network address mismatch.\n" f"Raiden is configured to use the smart contract " f"{configured_registry}, which conflicts with the current known " f"smart contracts {known_registries}") # Restore the current snapshot group state_change_qty = self.wal.storage.count_state_changes() self.snapshot_group = state_change_qty // SNAPSHOT_STATE_CHANGES_COUNT # Install the filters using the latest confirmed from_block value, # otherwise blockchain logs can be lost. self.install_all_blockchain_filters(self.default_registry, self.default_secret_registry, last_log_block_number) # Complete the first_run of the alarm task and synchronize with the # blockchain since the last run. # # Notes about setup order: # - The filters must be polled after the node state has been primed, # otherwise the state changes won't have effect. # - The alarm must complete its first run before the transport is started, # to reject messages for closed/settled channels. 
self.alarm.register_callback(self._callback_new_block) self.alarm.first_run(last_log_block_number) chain_state = views.state_from_raiden(self) self._initialize_payment_statuses(chain_state) self._initialize_transactions_queues(chain_state) self._initialize_messages_queues(chain_state) self._initialize_whitelists(chain_state) self._initialize_monitoring_services_queue(chain_state) self._initialize_ready_to_processed_events() if self.config["transport_type"] == "udp": endpoint_registration_greenlet.get( ) # re-raise if exception occurred # Start the side-effects: # - React to blockchain events # - React to incoming messages # - Send pending transactions # - Send pending message self.alarm.link_exception(self.on_error) self.transport.link_exception(self.on_error) self._start_transport(chain_state) self._start_alarm_task() log.debug("Raiden Service started", node=pex(self.address)) super().start() def _run(self, *args, **kwargs): # pylint: disable=method-hidden """ Busy-wait on long-lived subtasks/greenlets, re-raise if any error occurs """ self.greenlet.name = f"RaidenService._run node:{pex(self.address)}" try: self.stop_event.wait() except gevent.GreenletExit: # killed without exception self.stop_event.set() gevent.killall([self.alarm, self.transport]) # kill children raise # re-raise to keep killed status except Exception: self.stop() raise def stop(self): """ Stop the node gracefully. Raise if any stop-time error occurred on any subtask """ if self.stop_event.ready(): # not started return # Needs to come before any greenlets joining self.stop_event.set() # Filters must be uninstalled after the alarm task has stopped. Since # the events are polled by an alarm task callback, if the filters are # uninstalled before the alarm task is fully stopped the callback # `poll_blockchain_events` will fail. 
# # We need a timeout to prevent an endless loop from trying to # contact the disconnected client self.transport.stop() self.alarm.stop() self.transport.join() self.alarm.join() self.blockchain_events.uninstall_all_event_listeners() # Close storage DB to release internal DB lock self.wal.storage.conn.close() if self.db_lock is not None: self.db_lock.release() log.debug("Raiden Service stopped", node=pex(self.address)) @property def confirmation_blocks(self): return self.config["blockchain"]["confirmation_blocks"] @property def privkey(self): return self.chain.client.privkey def add_pending_greenlet(self, greenlet: Greenlet): """ Ensures an error on the passed greenlet crashes self/main greenlet. """ def remove(_): self.greenlets.remove(greenlet) self.greenlets.append(greenlet) greenlet.link_exception(self.on_error) greenlet.link_value(remove) def __repr__(self): return f"<{self.__class__.__name__} node:{pex(self.address)}>" def _start_transport(self, chain_state: ChainState): """ Initialize the transport and related facilities. Note: The transport must not be started before the node has caught up with the blockchain through `AlarmTask.first_run()`. This synchronization includes the on-chain channel state and is necessary to reject new messages for closed channels. """ assert self.alarm.is_primed(), f"AlarmTask not primed. node:{self!r}" assert self.ready_to_process_events, f"Event procossing disable. node:{self!r}" self.transport.start( raiden_service=self, message_handler=self.message_handler, prev_auth_data=chain_state.last_transport_authdata, ) for neighbour in views.all_neighbour_nodes(chain_state): if neighbour != ConnectionManager.BOOTSTRAP_ADDR: self.start_health_check_for(neighbour) def _start_alarm_task(self): """Start the alarm task. Note: The alarm task must be started only when processing events is allowed, otherwise side-effects of blockchain events will be ignored. """ assert self.ready_to_process_events, f"Event procossing disable. 
node:{self!r}" self.alarm.start() def _initialize_ready_to_processed_events(self): assert not self.transport assert not self.alarm # This flag /must/ be set to true before the transport or the alarm task is started self.ready_to_process_events = True def get_block_number(self) -> BlockNumber: assert self.wal, f"WAL object not yet initialized. node:{self!r}" return views.block_number(self.wal.state_manager.current_state) def on_message(self, message: Message): self.message_handler.on_message(self, message) def handle_and_track_state_change(self, state_change: StateChange): """ Dispatch the state change and does not handle the exceptions. When the method is used the exceptions are tracked and re-raised in the raiden service thread. """ for greenlet in self.handle_state_change(state_change): self.add_pending_greenlet(greenlet) def handle_state_change(self, state_change: StateChange) -> List[Greenlet]: """ Dispatch the state change and return the processing threads. Use this for error reporting, failures in the returned greenlets, should be re-raised using `gevent.joinall` with `raise_error=True`. """ assert self.wal, f"WAL not restored. 
node:{self!r}" log.debug( "State change", node=pex(self.address), state_change=_redact_secret( serialize.JSONSerializer.serialize(state_change)), ) old_state = views.state_from_raiden(self) new_state, raiden_event_list = self.wal.log_and_dispatch(state_change) for changed_balance_proof in views.detect_balance_proof_change( old_state, new_state): update_services_from_balance_proof(self, new_state, changed_balance_proof) log.debug( "Raiden events", node=pex(self.address), raiden_events=[ _redact_secret(serialize.JSONSerializer.serialize(event)) for event in raiden_event_list ], ) greenlets: List[Greenlet] = list() if self.ready_to_process_events: for raiden_event in raiden_event_list: greenlets.append( self.handle_event(chain_state=new_state, raiden_event=raiden_event)) state_changes_count = self.wal.storage.count_state_changes() new_snapshot_group = state_changes_count // SNAPSHOT_STATE_CHANGES_COUNT if new_snapshot_group > self.snapshot_group: log.debug("Storing snapshot", snapshot_id=new_snapshot_group) self.wal.snapshot() self.snapshot_group = new_snapshot_group return greenlets def handle_event(self, chain_state: ChainState, raiden_event: RaidenEvent) -> Greenlet: """Spawn a new thread to handle a Raiden event. This will spawn a new greenlet to handle each event, which is important for two reasons: - Blockchain transactions can be queued without interfering with each other. - The calling thread is free to do more work. This is specially important for the AlarmTask thread, which will eventually cause the node to send transactions when a given Block is reached (e.g. registering a secret or settling a channel). Important: This is spawing a new greenlet for /each/ transaction. It's therefore /required/ that there is *NO* order among these. 
""" return gevent.spawn(self._handle_event, chain_state, raiden_event) def _handle_event(self, chain_state: ChainState, raiden_event: RaidenEvent): assert isinstance(chain_state, ChainState) assert isinstance(raiden_event, RaidenEvent) try: self.raiden_event_handler.on_raiden_event(raiden=self, chain_state=chain_state, event=raiden_event) except RaidenRecoverableError as e: log.error(str(e)) except InvalidDBData: raise except RaidenUnrecoverableError as e: log_unrecoverable = ( self.config["environment_type"] == Environment.PRODUCTION and not self.config["unrecoverable_error_should_crash"]) if log_unrecoverable: log.error(str(e)) else: raise def set_node_network_state(self, node_address: Address, network_state: str): state_change = ActionChangeNodeNetworkState(node_address, network_state) self.handle_and_track_state_change(state_change) def start_health_check_for(self, node_address: Address): """Start health checking `node_address`. This function is a noop during initialization, because health checking can be started as a side effect of some events (e.g. new channel). For these cases the healthcheck will be started by `start_neighbours_healthcheck`. """ if self.transport: self.transport.start_health_check(node_address) def _callback_new_block(self, latest_block: Dict): """Called once a new block is detected by the alarm task. Note: This should be called only once per block, otherwise there will be duplicated `Block` state changes in the log. Therefore this method should be called only once a new block is mined with the corresponding block data from the AlarmTask. """ # User facing APIs, which have on-chain side-effects, force polled the # blockchain to update the node's state. This force poll is used to # provide a consistent view to the user, e.g. a channel open call waits # for the transaction to be mined and force polled the event to update # the node's state. 
This pattern introduced a race with the alarm task # and the task which served the user request, because the events are # returned only once per filter. The lock below is to protect against # these races (introduced by the commit # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d) with self.event_poll_lock: latest_block_number = latest_block["number"] # Handle testing with private chains. The block number can be # smaller than confirmation_blocks confirmed_block_number = max( GENESIS_BLOCK_NUMBER, latest_block_number - self.config["blockchain"]["confirmation_blocks"], ) confirmed_block = self.chain.client.web3.eth.getBlock( confirmed_block_number) # These state changes will be procesed with a block_number which is # /larger/ than the ChainState's block_number. for event in self.blockchain_events.poll_blockchain_events( confirmed_block_number): on_blockchain_event(self, event) # On restart the Raiden node will re-create the filters with the # ethereum node. These filters will have the from_block set to the # value of the latest Block state change. To avoid missing events # the Block state change is dispatched only after all of the events # have been processed. # # This means on some corner cases a few events may be applied # twice, this will happen if the node crashed and some events have # been processed but the Block state change has not been # dispatched. state_change = Block( block_number=confirmed_block_number, gas_limit=confirmed_block["gasLimit"], block_hash=BlockHash(bytes(confirmed_block["hash"])), ) # Note: It's important to /not/ block here, because this function # can be called from the alarm task greenlet, which should not # starve. self.handle_and_track_state_change(state_change) def _initialize_transactions_queues(self, chain_state: ChainState): """Initialize the pending transaction queue from the previous run. Note: This will only send the transactions which don't have their side-effects applied. 
Transactions which another node may have sent already will be detected by the alarm task's first run and cleared from the queue (e.g. A monitoring service update transfer). """ assert self.alarm.is_primed(), f"AlarmTask not primed. node:{self!r}" pending_transactions = views.get_pending_transactions(chain_state) log.debug( "Processing pending transactions", num_pending_transactions=len(pending_transactions), node=pex(self.address), ) for transaction in pending_transactions: try: self.raiden_event_handler.on_raiden_event( raiden=self, chain_state=chain_state, event=transaction) except RaidenRecoverableError as e: log.error(str(e)) except InvalidDBData: raise except RaidenUnrecoverableError as e: log_unrecoverable = ( self.config["environment_type"] == Environment.PRODUCTION and not self.config["unrecoverable_error_should_crash"]) if log_unrecoverable: log.error(str(e)) else: raise def _initialize_payment_statuses(self, chain_state: ChainState): """ Re-initialize targets_to_identifiers_to_statuses. Restore the PaymentStatus for any pending payment. This is not tied to a specific protocol message but to the lifecycle of a payment, i.e. the status is re-created if a payment itself has not completed. """ with self.payment_identifier_lock: for task in chain_state.payment_mapping.secrethashes_to_task.values( ): if not isinstance(task, InitiatorTask): continue # Every transfer in the transfers_list must have the same target # and payment_identifier, so using the first transfer is # sufficient. 
initiator = next( iter(task.manager_state.initiator_transfers.values())) transfer = initiator.transfer transfer_description = initiator.transfer_description target = transfer.target identifier = transfer.payment_identifier balance_proof = transfer.balance_proof self.targets_to_identifiers_to_statuses[target][ identifier] = PaymentStatus( payment_identifier=identifier, amount=transfer_description.amount, token_network_identifier=TokenNetworkID( balance_proof.token_network_identifier), payment_done=AsyncResult(), ) def _initialize_messages_queues(self, chain_state: ChainState): """Initialize all the message queues with the transport. Note: All messages from the state queues must be pushed to the transport before it's started. This is necessary to avoid a race where the transport processes network messages too quickly, queueing new messages before any of the previous messages, resulting in new messages being out-of-order. The Alarm task must be started before this method is called, otherwise queues for channel closed while the node was offline won't be properly cleared. It is not bad but it is suboptimal. """ assert not self.transport, f"Transport is running. node:{self!r}" assert self.alarm.is_primed(), f"AlarmTask not primed. node:{self!r}" events_queues = views.get_all_messagequeues(chain_state) for queue_identifier, event_queue in events_queues.items(): self.start_health_check_for(queue_identifier.recipient) for event in event_queue: message = message_from_sendevent(event) self.sign(message) self.transport.send_async(queue_identifier, message) def _initialize_monitoring_services_queue(self, chain_state: ChainState): """Send the monitoring requests for all current balance proofs. Note: The node must always send the *received* balance proof to the monitoring service, *before* sending its own locked transfer forward. 
If the monitoring service is updated after, then the following can happen: For a transfer A-B-C where this node is B - B receives T1 from A and processes it - B forwards its T2 to C * B crashes (the monitoring service is not updated) For the above scenario, the monitoring service would not have the latest balance proof received by B from A available with the lock for T1, but C would. If the channel B-C is closed and B does not come back online in time, the funds for the lock L1 can be lost. During restarts the rationale from above has to be replicated. Because the initialization code *is not* the same as the event handler. This means the balance proof updates must be done prior to the processing of the message queues. """ msg = ( "Transport was started before the monitoring service queue was updated. " "This can lead to safety issue. node:{self!r}") assert not self.transport, msg msg = "The node state was not yet recovered, cant read balance proofs. node:{self!r}" assert self.wal, msg current_balance_proofs = views.detect_balance_proof_change( old_state=ChainState( pseudo_random_generator=chain_state.pseudo_random_generator, block_number=GENESIS_BLOCK_NUMBER, block_hash=constants.EMPTY_HASH, our_address=chain_state.our_address, chain_id=chain_state.chain_id, ), current_state=chain_state, ) for balance_proof in current_balance_proofs: update_services_from_balance_proof(self, chain_state, balance_proof) def _initialize_whitelists(self, chain_state: ChainState): """ Whitelist neighbors and mediated transfer targets on transport """ for neighbour in views.all_neighbour_nodes(chain_state): if neighbour == ConnectionManager.BOOTSTRAP_ADDR: continue self.transport.whitelist(neighbour) events_queues = views.get_all_messagequeues(chain_state) for event_queue in events_queues.values(): for event in event_queue: if isinstance(event, SendLockedTransfer): transfer = event.transfer if transfer.initiator == self.address: self.transport.whitelist(address=transfer.target) def 
sign(self, message: Message): """ Sign message inplace. """ if not isinstance(message, SignedMessage): raise ValueError("{} is not signable.".format(repr(message))) message.sign(self.signer) def install_all_blockchain_filters( self, token_network_registry_proxy: TokenNetworkRegistry, secret_registry_proxy: SecretRegistry, from_block: BlockNumber, ): with self.event_poll_lock: node_state = views.state_from_raiden(self) token_networks = views.get_token_network_identifiers( node_state, token_network_registry_proxy.address) self.blockchain_events.add_token_network_registry_listener( token_network_registry_proxy=token_network_registry_proxy, contract_manager=self.contract_manager, from_block=from_block, ) self.blockchain_events.add_secret_registry_listener( secret_registry_proxy=secret_registry_proxy, contract_manager=self.contract_manager, from_block=from_block, ) for token_network in token_networks: token_network_proxy = self.chain.token_network( TokenNetworkAddress(token_network)) self.blockchain_events.add_token_network_listener( token_network_proxy=token_network_proxy, contract_manager=self.contract_manager, from_block=from_block, ) def connection_manager_for_token_network( self, token_network_identifier: TokenNetworkID) -> ConnectionManager: if not is_binary_address(token_network_identifier): raise InvalidAddress("token address is not valid.") known_token_networks = views.get_token_network_identifiers( views.state_from_raiden(self), self.default_registry.address) if token_network_identifier not in known_token_networks: raise InvalidAddress("token is not registered.") manager = self.tokennetworkids_to_connectionmanagers.get( token_network_identifier) if manager is None: manager = ConnectionManager(self, token_network_identifier) self.tokennetworkids_to_connectionmanagers[ token_network_identifier] = manager return manager def mediated_transfer_async( self, token_network_identifier: TokenNetworkID, amount: PaymentAmount, target: TargetAddress, identifier: PaymentID, 
fee: FeeAmount = MEDIATION_FEE, secret: Secret = None, secrethash: SecretHash = None, ) -> PaymentStatus: """ Transfer `amount` between this node and `target`. This method will start an asynchronous transfer, the transfer might fail or succeed depending on a couple of factors: - Existence of a path that can be used, through the usage of direct or intermediary channels. - Network speed, making the transfer sufficiently fast so it doesn't expire. """ if secret is None: if secrethash is None: secret = random_secret() else: secret = EMPTY_SECRET payment_status = self.start_mediated_transfer_with_secret( token_network_identifier=token_network_identifier, amount=amount, fee=fee, target=target, identifier=identifier, secret=secret, secrethash=secrethash, ) return payment_status def start_mediated_transfer_with_secret( self, token_network_identifier: TokenNetworkID, amount: PaymentAmount, fee: FeeAmount, target: TargetAddress, identifier: PaymentID, secret: Secret, secrethash: SecretHash = None, ) -> PaymentStatus: if secrethash is None: secrethash = sha3(secret) elif secrethash != sha3(secret): raise InvalidSecretHash( "provided secret and secret_hash do not match.") if len(secret) != SECRET_LENGTH: raise InvalidSecret("secret of invalid length.") # We must check if the secret was registered against the latest block, # even if the block is forked away and the transaction that registers # the secret is removed from the blockchain. The rationale here is that # someone else does know the secret, regardless of the chain state, so # the node must not use it to start a payment. # # For this particular case, it's preferable to use `latest` instead of # having a specific block_hash, because it's preferable to know if the secret # was ever known, rather than having a consistent view of the blockchain. 
secret_registered = self.default_secret_registry.is_secret_registered( secrethash=secrethash, block_identifier="latest") if secret_registered: raise RaidenUnrecoverableError( f"Attempted to initiate a locked transfer with secrethash {pex(secrethash)}." f" That secret is already registered onchain.") self.start_health_check_for(Address(target)) if identifier is None: identifier = create_default_identifier() with self.payment_identifier_lock: payment_status = self.targets_to_identifiers_to_statuses[ target].get(identifier) if payment_status: payment_status_matches = payment_status.matches( token_network_identifier, amount) if not payment_status_matches: raise PaymentConflict( "Another payment with the same id is in flight") return payment_status payment_status = PaymentStatus( payment_identifier=identifier, amount=amount, token_network_identifier=token_network_identifier, payment_done=AsyncResult(), ) self.targets_to_identifiers_to_statuses[target][ identifier] = payment_status init_initiator_statechange = initiator_init( raiden=self, transfer_identifier=identifier, transfer_amount=amount, transfer_secret=secret, transfer_secrethash=secrethash, transfer_fee=fee, token_network_identifier=token_network_identifier, target_address=target, ) # Dispatch the state change even if there are no routes to create the # wal entry. self.handle_and_track_state_change(init_initiator_statechange) return payment_status def mediate_mediated_transfer(self, transfer: LockedTransfer): init_mediator_statechange = mediator_init(self, transfer) self.handle_and_track_state_change(init_mediator_statechange) def target_mediated_transfer(self, transfer: LockedTransfer): self.start_health_check_for(Address(transfer.initiator)) init_target_statechange = target_init(transfer) self.handle_and_track_state_change(init_target_statechange) def maybe_upgrade_db(self) -> None: manager = UpgradeManager(db_filename=self.database_path, raiden=self, web3=self.chain.client.web3) manager.run()
class UDPTransport(Runnable):
    """UDP implementation of the Raiden transport.

    Outgoing protocol messages are kept in one queue per `QueueIdentifier`,
    each drained by its own `single_queue_send` greenlet. A `healthcheck`
    greenlet per recipient tracks reachability via Ping/Pong. Incoming packets
    are handled by `receive`, installed as the `DatagramServer` handler.
    """

    UDP_MAX_MESSAGE_SIZE = 1200

    def __init__(self, discovery, udpsocket, throttle_policy, config):
        super().__init__()
        # these values are initialized by the start method
        self.queueids_to_queues: typing.Dict
        self.raiden: RaidenService

        self.discovery = discovery
        self.config = config

        self.retry_interval = config['retry_interval']
        self.retries_before_backoff = config['retries_before_backoff']
        self.nat_keepalive_retries = config['nat_keepalive_retries']
        self.nat_keepalive_timeout = config['nat_keepalive_timeout']
        self.nat_invitation_timeout = config['nat_invitation_timeout']

        # A set event means "not running"; the transport starts out stopped.
        self.event_stop = Event()
        self.event_stop.set()

        self.greenlets = list()
        self.addresses_events = dict()

        self.messageids_to_asyncresults = dict()

        # Maps the addresses to a dict with the latest nonce (using a dict
        # because python integers are immutable)
        self.nodeaddresses_to_nonces = dict()

        # Discovery lookups are memoized in a TTL cache (50 entries, CACHE_TTL).
        cache = cachetools.TTLCache(
            maxsize=50,
            ttl=CACHE_TTL,
        )
        cache_wrapper = cachetools.cached(cache=cache)
        self.get_host_port = cache_wrapper(discovery.get)

        self.throttle_policy = throttle_policy
        self.server = DatagramServer(udpsocket, handle=self.receive)

    def start(
            self,
            raiden: RaidenService,
            message_handler: MessageHandler,
    ):
        """Bind the node and message handler, then start serving packets.

        Raises:
            RuntimeError: if the transport is already running.
        """
        if not self.event_stop.ready():
            raise RuntimeError('UDPTransport started while running')

        self.event_stop.clear()
        self.raiden = raiden
        self.message_handler = message_handler
        self.queueids_to_queues = dict()

        # Greenlets of a previous run have all finished: `stop` waits for them
        # and `_run` kills them on GreenletExit. Drop their bookkeeping,
        # otherwise `start_health_check` would treat those recipients as
        # already monitored and never spawn a new health check for them.
        self.greenlets = list()
        self.addresses_events = dict()

        # server.stop() clears the handle. Since this may be a restart the
        # handle must always be set
        self.server.set_handle(self.receive)

        self.server.start()
        super().start()

    def _run(self):
        """ Runnable main method, perform wait on long-running subtasks """
        try:
            self.event_stop.wait()
        except gevent.GreenletExit:  # killed without exception
            self.event_stop.set()
            gevent.killall(self.greenlets)  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()  # ensure cleanup and wait on subtasks
            raise

    def stop(self):
        """Stop the transport: subtasks first, then the socket. Idempotent."""
        if self.event_stop.ready():
            return  # double call, happens on normal stop, ignore

        self.event_stop.set()

        # Stop handling incoming packets, but don't close the socket. The
        # socket can only be safely closed after all outgoing tasks are stopped
        self.server.stop_accepting()

        # Stop processing the outgoing queues
        gevent.wait(self.greenlets)

        # All outgoing tasks are stopped. Now it's safe to close the socket. At
        # this point there might be some incoming message being processed,
        # keeping the socket open is not useful for these.
        self.server.stop()

        # Calling `.close()` on a gevent socket doesn't actually close the underlying os socket
        # so we do that ourselves here.
        # See: https://github.com/gevent/gevent/blob/master/src/gevent/_socket2.py#L208
        # and: https://groups.google.com/forum/#!msg/gevent/Ro8lRra3nH0/ZENgEXrr6M0J
        try:
            self.server._socket.close()  # pylint: disable=protected-access
        except socket.error:
            pass

        # Set all the pending results to False
        # NOTE(review): these results stay in `messageids_to_asyncresults`, and
        # `maybe_sendraw_with_result` returns an existing entry as-is; confirm
        # that callers after a restart handle an already-set result correctly.
        for async_result in self.messageids_to_asyncresults.values():
            async_result.set(False)

    def get_health_events(self, recipient):
        """ Starts a healthcheck task for `recipient` and returns a
        HealthEvents with locks to react on its current state.
        """
        if recipient not in self.addresses_events:
            self.start_health_check(recipient)

        return self.addresses_events[recipient]

    def start_health_check(self, recipient):
        """ Starts a task for healthchecking `recipient` if there is not
        one yet.
        """
        if recipient not in self.addresses_events:
            ping_nonce = self.nodeaddresses_to_nonces.setdefault(
                recipient,
                {'nonce': 0},  # HACK: Allows the task to mutate the object
            )

            events = healthcheck.HealthEvents(
                event_healthy=Event(),
                event_unhealthy=Event(),
            )

            self.addresses_events[recipient] = events

            greenlet_healthcheck = gevent.spawn(
                healthcheck.healthcheck,
                self,
                recipient,
                self.event_stop,
                events.event_healthy,
                events.event_unhealthy,
                self.nat_keepalive_retries,
                self.nat_keepalive_timeout,
                self.nat_invitation_timeout,
                ping_nonce,
            )
            greenlet_healthcheck.name = f'Healthcheck for {pex(recipient)}'
            greenlet_healthcheck.link_exception(self.on_error)
            self.greenlets.append(greenlet_healthcheck)

    def init_queue_for(
            self,
            queue_identifier: QueueIdentifier,
            items: typing.List[QueueItem_T],
    ) -> Queue_T:
        """ Create the queue identified by the queue_identifier
        and initialize it with `items`.
        """
        recipient = queue_identifier.recipient
        queue = self.queueids_to_queues.get(queue_identifier)
        assert queue is None

        queue = NotifyingQueue(items=items)
        self.queueids_to_queues[queue_identifier] = queue

        events = self.get_health_events(recipient)

        # The last argument is presumably the maximum retry backoff -- confirm
        # against single_queue_send's signature.
        greenlet_queue = gevent.spawn(
            single_queue_send,
            self,
            recipient,
            queue,
            queue_identifier,
            self.event_stop,
            events.event_healthy,
            events.event_unhealthy,
            self.retries_before_backoff,
            self.retry_interval,
            self.retry_interval * 10,
        )

        if queue_identifier.channel_identifier == CHANNEL_IDENTIFIER_GLOBAL_QUEUE:
            greenlet_queue.name = f'Queue for {pex(recipient)} - global'
        else:
            greenlet_queue.name = (
                f'Queue for {pex(recipient)} - {queue_identifier.channel_identifier}'
            )

        greenlet_queue.link_exception(self.on_error)
        self.greenlets.append(greenlet_queue)

        log.debug(
            'new queue created for',
            node=pex(self.raiden.address),
            queue_identifier=queue_identifier,
            items_qty=len(items),
        )

        return queue

    def get_queue_for(
            self,
            queue_identifier: QueueIdentifier,
    ) -> Queue_T:
        """ Return the queue identified by the given queue identifier.

        If the queue doesn't exist it will be instantiated.
        """
        queue = self.queueids_to_queues.get(queue_identifier)

        if queue is None:
            items = ()
            queue = self.init_queue_for(queue_identifier, items)

        return queue

    def send_async(
            self,
            queue_identifier: QueueIdentifier,
            message: 'Message',
    ):
        """ Send a new ordered message to recipient.

        Messages that use the same `queue_identifier` are ordered.

        Raises:
            ValueError: invalid recipient address, a transport-level message
                (Delivered/Ping/Pong), or an encoding larger than
                UDP_MAX_MESSAGE_SIZE.
        """
        recipient = queue_identifier.recipient
        if not is_binary_address(recipient):
            raise ValueError('Invalid address {}'.format(pex(recipient)))

        # These are not protocol messages, but transport specific messages
        if isinstance(message, (Delivered, Ping, Pong)):
            raise ValueError('Do not use send for {} messages'.format(
                message.__class__.__name__))

        messagedata = message.encode()
        if len(messagedata) > self.UDP_MAX_MESSAGE_SIZE:
            raise ValueError(
                'message size exceeds the maximum {}'.format(self.UDP_MAX_MESSAGE_SIZE),
            )

        # message identifiers must be unique
        message_id = message.message_identifier

        # ignore duplicates
        if message_id not in self.messageids_to_asyncresults:
            self.messageids_to_asyncresults[message_id] = AsyncResult()

            queue = self.get_queue_for(queue_identifier)
            queue.put((messagedata, message_id))
            assert queue.is_set()

            log.debug(
                'Message queued',
                node=pex(self.raiden.address),
                queue_identifier=queue_identifier,
                queue_size=len(queue),
                message=message,
            )

    def maybe_send(self, recipient: typing.Address, message: Message):
        """ Send message to recipient if the transport is running.

        Raises:
            InvalidAddress: if `recipient` is not a binary address.
        """
        if not is_binary_address(recipient):
            raise InvalidAddress('Invalid address {}'.format(pex(recipient)))

        messagedata = message.encode()
        host_port = self.get_host_port(recipient)

        self.maybe_sendraw(host_port, messagedata)

    def maybe_sendraw_with_result(
            self,
            recipient: typing.Address,
            messagedata: bytes,
            message_id: typing.MessageID,
    ) -> AsyncResult:
        """ Send message to recipient if the transport is running.

        Returns:
            An AsyncResult that will be set once the message is delivered. As
            long as the message has not been acknowledged with a Delivered
            message the function will return the same AsyncResult.
        """
        async_result = self.messageids_to_asyncresults.get(message_id)
        if async_result is None:
            async_result = AsyncResult()
            self.messageids_to_asyncresults[message_id] = async_result

        host_port = self.get_host_port(recipient)
        self.maybe_sendraw(host_port, messagedata)

        return async_result

    def maybe_sendraw(self, host_port: typing.Tuple[str, int], messagedata: bytes):
        """ Send message to recipient if the transport is running.

        `host_port` is the (host, port) pair returned by discovery; the call
        may first sleep as dictated by the throttle policy.
        """

        # Don't sleep if timeout is zero, otherwise a context-switch is done
        # and the message is delayed, increasing it's latency
        sleep_timeout = self.throttle_policy.consume(1)
        if sleep_timeout:
            gevent.sleep(sleep_timeout)

        # Check the udp socket is still available before trying to send the
        # message. There must be *no context-switches after this test*.
        if hasattr(self.server, 'socket'):
            self.server.sendto(
                messagedata,
                host_port,
            )

    def receive(
            self,
            messagedata: bytes,
            host_port: typing.Tuple[str, int],  # pylint: disable=unused-argument
    ) -> bool:
        """ Handle an UDP packet.

        Returns:
            True if the packet was decoded and dispatched; False if it was too
            large, could not be decoded, or had an unknown cmdid.
        """
        # pylint: disable=unidiomatic-typecheck

        if len(messagedata) > self.UDP_MAX_MESSAGE_SIZE:
            log.warning(
                'Invalid message: Packet larger than maximum size',
                node=pex(self.raiden.address),
                message=hexlify(messagedata),
                length=len(messagedata),
            )
            return False

        try:
            message = decode(messagedata)
        except InvalidProtocolMessage as e:
            log.warning(
                'Invalid protocol message',
                error=str(e),
                node=pex(self.raiden.address),
                message=hexlify(messagedata),
            )
            return False

        # Exact type checks: transport messages are routed to their own
        # handlers, everything else is a protocol message.
        if type(message) == Pong:
            self.receive_pong(message)
        elif type(message) == Ping:
            self.receive_ping(message)
        elif type(message) == Delivered:
            self.receive_delivered(message)
        elif message is not None:
            self.receive_message(message)
        else:
            log.warning(
                'Invalid message: Unknown cmdid',
                node=pex(self.raiden.address),
                message=hexlify(messagedata),
            )
            return False

        return True

    def receive_message(self, message: Message):
        """ Handle a Raiden protocol message.

        The protocol requires durability of the messages. The UDP transport
        relies on the node's WAL for durability. The message will be converted
        to a state change, saved to the WAL, and *processed* before the
        durability is confirmed, which is a stronger property than what is
        required of any transport.
        """
        self.message_handler.on_message(self.raiden, message)

        # Sending Delivered after the message is decoded and *processed*
        # gives a stronger guarantee than what is required from a
        # transport.
        #
        # Alternatives are, from weakest to strongest options:
        # - Just save it on disk and asynchronously process the messages
        # - Decode it, save to the WAL, and asynchronously process the
        #   state change
        # - Decode it, save to the WAL, and process it (the current
        #   implementation)
        delivered_message = Delivered(message.message_identifier)
        self.raiden.sign(delivered_message)

        self.maybe_send(
            message.sender,
            delivered_message,
        )

    def receive_delivered(self, delivered: Delivered):
        """ Handle a Delivered message.

        The Delivered message is how the UDP transport guarantees persistence
        by the partner node. The message itself is not part of the raiden
        protocol, but it's required by this transport to provide the required
        properties.
        """
        self.message_handler.on_message(self.raiden, delivered)

        message_id = delivered.delivered_message_identifier
        # Look up and remove the result from this transport's own map in one
        # step; keeping it around would leak one AsyncResult per message.
        async_result = self.messageids_to_asyncresults.pop(message_id, None)

        if async_result is not None:
            async_result.set()
        else:
            log.warning(
                'Unknown delivered message received',
                message_id=message_id,
            )

    # Pings and Pongs are used to check the health status of another node. They
    # are /not/ part of the raiden protocol, only part of the UDP transport,
    # therefore these messages are not forwarded to the message handler.
    def receive_ping(self, ping: Ping):
        """ Handle a Ping message by answering with a Pong. """

        log_healthcheck.debug(
            'Ping received',
            node=pex(self.raiden.address),
            message_id=ping.nonce,
            message=ping,
            sender=pex(ping.sender),
        )

        pong = Pong(ping.nonce)
        self.raiden.sign(pong)

        try:
            self.maybe_send(ping.sender, pong)
        except (InvalidAddress, UnknownAddress) as e:
            log.debug("Couldn't send the `Pong` message", e=e)

    def receive_pong(self, pong: Pong):
        """ Handles a Pong message. """

        # Must match the key the healthcheck registered for its Ping.
        message_id = ('ping', pong.nonce, pong.sender)
        async_result = self.messageids_to_asyncresults.get(message_id)

        if async_result is not None:
            log_healthcheck.debug(
                'Pong received',
                node=pex(self.raiden.address),
                sender=pex(pong.sender),
                message_id=pong.nonce,
            )

            async_result.set(True)

        else:
            log_healthcheck.warn(
                'Unknown pong received',
                message_id=message_id,
            )

    def get_ping(self, nonce: int) -> bytes:
        """ Returns the encoded bytes of a signed Ping message.

        Note: Ping messages don't have an enforced ordering, so a Ping message
        with a higher nonce may be acknowledged first.
        """
        message = Ping(
            nonce=nonce,
            current_protocol_version=constants.PROTOCOL_VERSION,
        )
        self.raiden.sign(message)
        message_data = message.encode()

        return message_data

    def set_node_network_state(self, node_address: typing.Address, node_state):
        """Forward a network-state change for `node_address` to the node."""
        state_change = ActionChangeNodeNetworkState(node_address, node_state)
        self.raiden.handle_state_change(state_change)
class Subscriber(object):
    """Receiving end of a `Publisher` subscription.

    Messages pushed with `send` (or by calling the instance) are buffered in an
    internal queue and handed out by `recv` / iteration. `close` enqueues a
    `StopIteration` sentinel; once `recv` consumes it the queue is dropped and
    iteration ends. `closed` is only true after that point, so messages
    buffered before the close can still be drained.
    """
    __slots__ = ('_pub', '_queue', '_closed', '_replyfn')

    def __init__(self, pub):
        assert isinstance(pub, Publisher)
        self._pub = pub
        self._queue = Queue()
        self._closed = Event()
        pub.attach(self)

    def __len__(self):
        """Number of buffered messages; 0 once the queue has been dropped."""
        # len() must always get an int; returning None here raised TypeError.
        if self._queue is None:
            return 0
        return self._queue.qsize()

    def __call__(self, msg):
        return self.send(msg)

    def __enter__(self):
        return self

    def __del__(self):
        self.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __iter__(self):
        """Yield messages until the close sentinel is consumed (or `recv` returns None)."""
        if not self.closed:
            while self._queue is not None:
                msg = self.recv()
                if msg is None:
                    break
                yield msg

    def recv(self):
        """Block for the next message and return it.

        Returns None when the queue has already been dropped, or when the
        `StopIteration` sentinel is consumed; in the latter case the queue is
        dropped and the subscriber closed.
        """
        if self._queue is not None:
            msg = self._queue.get()
            if msg is StopIteration:
                self._queue = None
                self.close()
                return None
            return msg

    def send(self, msg):
        """Buffer `msg`; sending `StopIteration` closes the subscriber instead.

        Raises:
            RuntimeError: if the subscriber is fully closed.
        """
        if self.closed:
            # XXX: raise better exception
            raise RuntimeError("Closed")
        if msg is StopIteration:
            return self.close()
        self._queue.put_nowait(msg)

    @property
    def closed(self):
        return self._closed.ready() and self._queue is None

    def close(self):
        """Close the subscription.

        Enqueues the `StopIteration` sentinel (to wake a blocked reader) while
        the queue exists, and detaches from the publisher exactly once. Safe to
        call repeatedly, e.g. explicitly and again from `__del__`.
        """
        if self._queue is not None:
            self._queue.put(StopIteration)
        # Check the event rather than `closed`: `closed` stays False until the
        # sentinel is consumed, which used to detach again on every call.
        if not self._closed.ready():
            self._pub.detach(self)
            self._closed.set()

    def datastream(self):
        return DataStream(self)
class RaidenService(Runnable):
    """ A Raiden node.

    Owns the node's write-ahead log (WAL), the alarm task that follows new
    blocks, the blockchain event listeners and the transport, and dispatches
    state changes and the events they produce.
    """

    def __init__(
            self,
            chain: BlockChainService,
            query_start_block: BlockNumber,
            default_registry: TokenNetworkRegistry,
            default_secret_registry: SecretRegistry,
            private_key_bin,
            transport,
            raiden_event_handler,
            message_handler,
            config,
            discovery=None,
    ):
        """Wire up the node's components; nothing is started here.

        Raises:
            ValueError: if `private_key_bin` is not 32 bytes.
        """
        super().__init__()
        if not isinstance(private_key_bin, bytes) or len(private_key_bin) != 32:
            raise ValueError('invalid private_key')

        self.tokennetworkids_to_connectionmanagers = dict()
        self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict(dict)

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config
        self.privkey = private_key_bin
        self.address = privatekey_to_address(private_key_bin)
        self.discovery = discovery

        self.private_key = PrivateKey(private_key_bin)
        self.pubkey = self.private_key.public_key.format(compressed=False)
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler
        self.message_handler = message_handler

        # A set stop_event means "not running"; start() clears it.
        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        # The WAL is created by start(); snapshot_group is restored there too.
        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state changes events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.contract_manager = ContractManager(config['contracts_path'])
        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir
            # Prevent concurrent access to the same db
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            # In-memory database: no directory, no file lock.
            # NOTE(review): `serialization_file` is only assigned in this
            # branch, so on-disk nodes never have the attribute -- confirm
            # nothing reads it.
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        self.event_poll_lock = gevent.lock.Semaphore()
        self.gas_reserve_lock = gevent.lock.Semaphore()
        self.payment_identifier_lock = gevent.lock.Semaphore()

    def start(self):
        """ Start the node synchronously. Raises directly if anything went wrong on startup

        Order: take the db lock, restore (or initialize) the WAL, install the
        blockchain filters, run the alarm task's first run, replay pending
        transactions and queued messages, and only then start the transport
        and the alarm task.

        Raises:
            RuntimeError: if the node is already started, or the restored state
                knows payment networks that do not include the configured
                registry.
        """
        if not self.stop_event.ready():
            raise RuntimeError(f'{self!r} already started')
        self.stop_event.clear()

        if self.database_dir is not None:
            # timeout=0: fail immediately if another process holds the lock.
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # start the registration early to speed up the start
        # (the greenlet only exists for the 'udp' transport and is only joined
        # under the same condition at the end of this method)
        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet = gevent.spawn(
                self.discovery.register,
                self.address,
                self.config['transport']['udp']['external_ip'],
                self.config['transport']['udp']['external_port'],
            )

        storage = sqlite.SQLiteStorage(self.database_path, serialize.JSONSerializer())
        self.wal = wal.restore_to_state_change(
            transition_function=node.state_transition,
            storage=storage,
            state_change_identifier='latest',
        )

        if self.wal.state_manager.current_state is None:
            log.debug(
                'No recoverable state available, created inital state',
                node=pex(self.address),
            )
            # On first run Raiden needs to fetch all events for the payment
            # network, to reconstruct all token network graphs and find opened
            # channels
            last_log_block_number = self.query_start_block

            state_change = ActionInitChain(
                random.Random(),
                last_log_block_number,
                self.chain.node_address,
                self.chain.network_id,
            )
            self.handle_state_change(state_change)

            payment_network = PaymentNetworkState(
                self.default_registry.address,
                [],  # empty list of token network states as it's the node's startup
            )
            state_change = ContractReceiveNewPaymentNetwork(
                constants.EMPTY_HASH,
                payment_network,
                last_log_block_number,
            )
            self.handle_state_change(state_change)
        else:
            # The `Block` state change is dispatched only after all the events
            # for that given block have been processed, filters can be safely
            # installed starting from this position without losing events.
            last_log_block_number = views.block_number(self.wal.state_manager.current_state)
            log.debug(
                'Restored state from WAL',
                last_restored_block=last_log_block_number,
                node=pex(self.address),
            )

            # Refuse to start against a registry the stored state does not know.
            known_networks = views.get_payment_network_identifiers(views.state_from_raiden(self))
            if known_networks and self.default_registry.address not in known_networks:
                configured_registry = pex(self.default_registry.address)
                known_registries = lpex(known_networks)
                raise RuntimeError(
                    f'Token network address mismatch.\n'
                    f'Raiden is configured to use the smart contract '
                    f'{configured_registry}, which conflicts with the current known '
                    f'smart contracts {known_registries}',
                )

        # Restore the current snapshot group
        state_change_qty = self.wal.storage.count_state_changes()
        self.snapshot_group = state_change_qty // SNAPSHOT_STATE_CHANGES_COUNT

        # Install the filters using the correct from_block value, otherwise
        # blockchain logs can be lost.
        self.install_all_blockchain_filters(
            self.default_registry,
            self.default_secret_registry,
            last_log_block_number,
        )

        # Complete the first_run of the alarm task and synchronize with the
        # blockchain since the last run.
        #
        # Notes about setup order:
        # - The filters must be polled after the node state has been primed,
        # otherwise the state changes won't have effect.
        # - The alarm must complete its first run before the transport is started,
        #   to reject messages for closed/settled channels.
        self.alarm.register_callback(self._callback_new_block)
        # While dispatch_events_lock is held, handle_state_change logs state
        # changes to the WAL but returns [] instead of dispatching events.
        with self.dispatch_events_lock:
            self.alarm.first_run(last_log_block_number)

        chain_state = views.state_from_raiden(self)
        self._initialize_transactions_queues(chain_state)
        self._initialize_whitelists(chain_state)

        # send messages in queue before starting transport,
        # this is necessary to avoid a race where, if the transport is started
        # before the messages are queued, actions triggered by it can cause new
        # messages to be enqueued before these older ones
        self._initialize_messages_queues(chain_state)

        # The transport must not ever be started before the alarm task's
        # `first_run()` has been, because it's this method which synchronizes the
        # node with the blockchain, including the channel's state (if the channel
        # is closed on-chain new messages must be rejected, which will not be the
        # case if the node is not synchronized)
        self.transport.start(
            raiden_service=self,
            message_handler=self.message_handler,
            prev_auth_data=chain_state.last_transport_authdata,
        )

        # First run has been called above!
        self.alarm.start()

        # exceptions on these subtasks should crash the app and bubble up
        self.alarm.link_exception(self.on_error)
        self.transport.link_exception(self.on_error)

        # Health check needs the transport layer
        self.start_neighbours_healthcheck(chain_state)

        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet.get()  # re-raise if exception occurred

        log.debug('Raiden Service started', node=pex(self.address))
        super().start()

    def _run(self, *args, **kwargs):  # pylint: disable=method-hidden
        """ Busy-wait on long-lived subtasks/greenlets, re-raise if any error occurs """
        try:
            self.stop_event.wait()
        except gevent.GreenletExit:  # killed without exception
            self.stop_event.set()
            gevent.killall([self.alarm, self.transport])  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()
            raise

    def stop(self):
        """ Stop the node gracefully. Raise if any stop-time error occurred on any subtask """
        if self.stop_event.ready():  # not started
            return

        # Needs to come before any greenlets joining
        self.stop_event.set()

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        # NOTE(review): no timeout is actually passed to the join() calls
        # below -- confirm whether one is still needed.
        self.transport.stop()
        self.alarm.stop()

        self.transport.join()
        self.alarm.join()

        self.blockchain_events.uninstall_all_event_listeners()

        if self.db_lock is not None:
            self.db_lock.release()

        log.debug('Raiden Service stopped', node=pex(self.address))

    def add_pending_greenlet(self, greenlet: gevent.Greenlet):
        """Report exceptions raised by `greenlet` to `self.on_error`."""
        greenlet.link_exception(self.on_error)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def start_neighbours_healthcheck(self, chain_state: ChainState):
        """Start health checks for all neighbours except the bootstrap address."""
        for neighbour in views.all_neighbour_nodes(chain_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def get_block_number(self) -> BlockNumber:
        """Return the block number of the current state held by the WAL."""
        return views.block_number(self.wal.state_manager.current_state)

    def on_message(self, message: Message):
        """Forward `message` to the configured message handler."""
        self.message_handler.on_message(self, message)

    def handle_state_change(self, state_change: StateChange):
        """Log and apply `state_change` through the WAL, then dispatch its events.

        Returns:
            The list of events produced by the state change, or [] when
            `dispatch_events_lock` is held; in that case the events are not
            dispatched here and no snapshot check is done.

        Raises:
            InvalidDBData: always re-raised from the event handler.
            RaidenUnrecoverableError: re-raised unless running in production
                with `unrecoverable_error_should_crash` disabled.
        """
        log.debug(
            'State change',
            node=pex(self.address),
            state_change=_redact_secret(serialize.JSONSerializer.serialize(state_change)),
        )

        event_list = self.wal.log_and_dispatch(state_change)

        if self.dispatch_events_lock.locked():
            return []

        for event in event_list:
            log.debug(
                'Raiden event',
                node=pex(self.address),
                raiden_event=_redact_secret(serialize.JSONSerializer.serialize(event)),
            )

            try:
                self.raiden_event_handler.on_raiden_event(
                    raiden=self,
                    event=event,
                )
            except RaidenRecoverableError as e:
                log.error(str(e))
            except InvalidDBData:
                raise
            except RaidenUnrecoverableError as e:
                # Only a production node configured not to crash swallows it.
                log_unrecoverable = (
                    self.config['environment_type'] == Environment.PRODUCTION and
                    not self.config['unrecoverable_error_should_crash']
                )
                if log_unrecoverable:
                    log.error(str(e))
                else:
                    raise

        # Take a snapshot every SNAPSHOT_STATE_CHANGES_COUNT
        # TODO: Gather more data about storage requirements
        # and update the value to specify how often we need
        # capturing a snapshot should take place
        new_snapshot_group = self.wal.storage.count_state_changes() // SNAPSHOT_STATE_CHANGES_COUNT
        if new_snapshot_group > self.snapshot_group:
            log.debug('Storing snapshot', snapshot_id=new_snapshot_group)
            self.wal.snapshot()
            self.snapshot_group = new_snapshot_group

        return event_list

    def set_node_network_state(self, node_address: Address, network_state: str):
        """Record a network-state change for `node_address` as a state change."""
        state_change = ActionChangeNodeNetworkState(node_address, network_state)
        self.handle_state_change(state_change)

    def start_health_check_for(self, node_address: Address):
        # This function is a noop during initialization. It can be called
        # through the alarm task while polling for new channel events. The
        # healthcheck will be started by self.start_neighbours_healthcheck()
        if self.transport:
            self.transport.start_health_check(node_address)

    def _callback_new_block(self, latest_block: Dict):
        """Called once a new block is detected by the alarm task.

        Note:
            This should be called only once per block, otherwise there will be
            duplicated `Block` state changes in the log.

            Therefore this method should be called only once a new block is
            mined with the corresponding block data from the AlarmTask.
        """
        # User facing APIs, which have on-chain side-effects, force polled the
        # blockchain to update the node's state. This force poll is used to
        # provide a consistent view to the user, e.g. a channel open call waits
        # for the transaction to be mined and force polled the event to update
        # the node's state. This pattern introduced a race with the alarm task
        # and the task which served the user request, because the events are
        # returned only once per filter.
The lock below is to protect against # these races (introduced by the commit # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d) with self.event_poll_lock: latest_block_number = latest_block['number'] confirmation_blocks = self.config['blockchain']['confirmation_blocks'] confirmed_block_number = latest_block_number - confirmation_blocks confirmed_block = self.chain.client.web3.eth.getBlock(confirmed_block_number) # handle testing private chains confirmed_block_number = max(GENESIS_BLOCK_NUMBER, confirmed_block_number) for event in self.blockchain_events.poll_blockchain_events(confirmed_block_number): # These state changes will be procesed with a block_number # which is /larger/ than the ChainState's block_number. on_blockchain_event(self, event) # On restart the Raiden node will re-create the filters with the # ethereum node. These filters will have the from_block set to the # value of the latest Block state change. To avoid missing events # the Block state change is dispatched only after all of the events # have been processed. # # This means on some corner cases a few events may be applied # twice, this will happen if the node crashed and some events have # been processed but the Block state change has not been # dispatched. 
state_change = Block( block_number=confirmed_block_number, gas_limit=confirmed_block['gasLimit'], block_hash=bytes(confirmed_block['hash']), ) self.handle_state_change(state_change) def _register_payment_status( self, target: TargetAddress, identifier: PaymentID, balance_proof: BalanceProofUnsignedState, ): with self.payment_identifier_lock: self.targets_to_identifiers_to_statuses[target][identifier] = PaymentStatus( payment_identifier=identifier, amount=balance_proof.transferred_amount, token_network_identifier=balance_proof.token_network_identifier, payment_done=AsyncResult(), ) def _initialize_transactions_queues(self, chain_state: ChainState): pending_transactions = views.get_pending_transactions(chain_state) log.debug( 'Processing pending transactions', num_pending_transactions=len(pending_transactions), node=pex(self.address), ) with self.dispatch_events_lock: for transaction in pending_transactions: try: self.raiden_event_handler.on_raiden_event(self, transaction) except RaidenRecoverableError as e: log.error(str(e)) except InvalidDBData: raise except RaidenUnrecoverableError as e: log_unrecoverable = ( self.config['environment_type'] == Environment.PRODUCTION and not self.config['unrecoverable_error_should_crash'] ) if log_unrecoverable: log.error(str(e)) else: raise def _initialize_messages_queues(self, chain_state: ChainState): """ Push the queues to the transport and populate targets_to_identifiers_to_statuses. 
""" events_queues = views.get_all_messagequeues(chain_state) for queue_identifier, event_queue in events_queues.items(): self.start_health_check_for(queue_identifier.recipient) for event in event_queue: is_initiator = ( type(event) == SendLockedTransfer and event.transfer.initiator == self.address ) if is_initiator: self._register_payment_status( target=event.transfer.target, identifier=event.transfer.payment_identifier, balance_proof=event.transfer.balance_proof, ) message = message_from_sendevent(event, self.address) self.sign(message) self.transport.send_async(queue_identifier, message) def _initialize_whitelists(self, chain_state: ChainState): """ Whitelist neighbors and mediated transfer targets on transport """ for neighbour in views.all_neighbour_nodes(chain_state): if neighbour == ConnectionManager.BOOTSTRAP_ADDR: continue self.transport.whitelist(neighbour) events_queues = views.get_all_messagequeues(chain_state) for event_queue in events_queues.values(): for event in event_queue: is_initiator = ( type(event) == SendLockedTransfer and event.transfer.initiator == self.address ) if is_initiator: self.transport.whitelist(address=event.transfer.target) def sign(self, message: Message): """ Sign message inplace. 
""" if not isinstance(message, SignedMessage): raise ValueError('{} is not signable.'.format(repr(message))) message.sign(self.private_key) def install_all_blockchain_filters( self, token_network_registry_proxy: TokenNetworkRegistry, secret_registry_proxy: SecretRegistry, from_block: BlockNumber, ): with self.event_poll_lock: node_state = views.state_from_raiden(self) token_networks = views.get_token_network_identifiers( node_state, token_network_registry_proxy.address, ) self.blockchain_events.add_token_network_registry_listener( token_network_registry_proxy=token_network_registry_proxy, contract_manager=self.contract_manager, from_block=from_block, ) self.blockchain_events.add_secret_registry_listener( secret_registry_proxy=secret_registry_proxy, contract_manager=self.contract_manager, from_block=from_block, ) for token_network in token_networks: token_network_proxy = self.chain.token_network( TokenNetworkAddress(token_network), ) self.blockchain_events.add_token_network_listener( token_network_proxy=token_network_proxy, contract_manager=self.contract_manager, from_block=from_block, ) def connection_manager_for_token_network( self, token_network_identifier: TokenNetworkID, ) -> ConnectionManager: if not is_binary_address(token_network_identifier): raise InvalidAddress('token address is not valid.') known_token_networks = views.get_token_network_identifiers( views.state_from_raiden(self), self.default_registry.address, ) if token_network_identifier not in known_token_networks: raise InvalidAddress('token is not registered.') manager = self.tokennetworkids_to_connectionmanagers.get(token_network_identifier) if manager is None: manager = ConnectionManager(self, token_network_identifier) self.tokennetworkids_to_connectionmanagers[token_network_identifier] = manager return manager def mediated_transfer_async( self, token_network_identifier: TokenNetworkID, amount: TokenAmount, target: TargetAddress, identifier: PaymentID, ) -> AsyncResult: """ Transfer `amount` 
between this node and `target`. This method will start an asynchronous transfer, the transfer might fail or succeed depending on a couple of factors: - Existence of a path that can be used, through the usage of direct or intermediary channels. - Network speed, making the transfer sufficiently fast so it doesn't expire. """ secret = random_secret() async_result = self.start_mediated_transfer_with_secret( token_network_identifier, amount, target, identifier, secret, ) return async_result def start_mediated_transfer_with_secret( self, token_network_identifier: TokenNetworkID, amount: TokenAmount, target: TargetAddress, identifier: PaymentID, secret: Secret, ) -> AsyncResult: secret_hash = sha3(secret) if self.default_secret_registry.check_registered(secret_hash): raise RaidenUnrecoverableError( f'Attempted to initiate a locked transfer with secrethash {pex(secret_hash)}.' f' That secret is already registered onchain.', ) self.start_health_check_for(Address(target)) if identifier is None: identifier = create_default_identifier() with self.payment_identifier_lock: payment_status = self.targets_to_identifiers_to_statuses[target].get(identifier) if payment_status: payment_status_matches = payment_status.matches( token_network_identifier, amount, ) if not payment_status_matches: raise PaymentConflict( 'Another payment with the same id is in flight', ) return payment_status.payment_done payment_status = PaymentStatus( payment_identifier=identifier, amount=amount, token_network_identifier=token_network_identifier, payment_done=AsyncResult(), ) self.targets_to_identifiers_to_statuses[target][identifier] = payment_status init_initiator_statechange = initiator_init( raiden=self, transfer_identifier=identifier, transfer_amount=amount, transfer_secret=secret, token_network_identifier=token_network_identifier, target_address=target, ) # Dispatch the state change even if there are no routes to create the # wal entry. 
self.handle_state_change(init_initiator_statechange) return payment_status.payment_done def mediate_mediated_transfer(self, transfer: LockedTransfer): init_mediator_statechange = mediator_init(self, transfer) self.handle_state_change(init_mediator_statechange) def target_mediated_transfer(self, transfer: LockedTransfer): self.start_health_check_for(transfer.initiator) init_target_statechange = target_init(transfer) self.handle_state_change(init_target_statechange)
class TransactionDummy(ATransaction):
    """Simulated transaction participant driven by pings, commits and rollbacks.

    Lifecycle as implemented below:

    * ``_spawn`` starts ``ping_timeout_thread``, which sets ``fail`` when no
      ping/done/fail event arrives within ``2 * ping_timeout`` seconds.
    * ``do_work`` starts ``result_thread``, which sleeps ``result_timeout``
      seconds, stores the resource as ``result``, sets ``ready_commit`` and
      PUTs the result to ``callback_url``.
    * ``do_commit`` sets ``commit``, waits a jittered delay and PUTs a
      ``done`` notification to ``callback_url``.
    * ``do_rollback`` sets ``fail``.

    NOTE(review): ``self.id`` and the ``done``/``fail``/``commit``/
    ``ready_commit`` events are not defined in this class; presumably they
    come from ``ATransaction`` -- confirm.
    """

    ping_timeout = 5  # sec; default used when the constructor gets None
    result_timeout = 20  # sec; default used when the constructor gets None
    # Timeout for every callback HTTP request, so that a stalled callback
    # endpoint cannot block a greenlet forever.
    callback_timeout = 5  # sec

    def __init__(self, callback_url: str, local_timeout: float, ping_timeout=None, result_timeout=None):
        """
        :param callback_url: URL that receives the result and "done" PUT requests.
        :param local_timeout: seconds; within this class it is only stored and
            reported in the init debug event.
        :param ping_timeout: seconds; falls back to ``TransactionDummy.ping_timeout``.
        :param result_timeout: seconds; falls back to ``TransactionDummy.result_timeout``.
        """
        super().__init__(ObjectId())
        self.callback_url = callback_url
        self.local_timeout = local_timeout
        self.ping_timeout = ping_timeout if ping_timeout is not None else TransactionDummy.ping_timeout  # type: float
        self.result_timeout = result_timeout if result_timeout is not None else TransactionDummy.result_timeout  # type: float
        # Callback key: SHA-256 of the transaction id followed by the current
        # time in microseconds XOR'ed with a random 20-bit value.
        salt = int(time.time() * 10**6) ^ randint(0, 2**20)
        self.key = sha256((str(self.id) + str(salt)).encode("utf-8")).hexdigest()
        debug_SSE.event({
            "event": "init",
            "t": datetime.now(),
            "data": {
                # timeouts are reported in milliseconds
                "callback_url": self.callback_url,
                "local_timeout": self.local_timeout * 1000,
                "result_timeout": self.result_timeout * 1000,
                "ping_timeout": self.ping_timeout * 1000,
                "key": self.key,
                "_id": self.id,
            },
        })  # DEBUG init
        self._ping = Event()
        self.result = AsyncResult()
        self.ping_timeout_thread_obj = None  # type: Greenlet
        self.result_thread_obj = None  # type: Greenlet

    @g_async
    def _spawn(self):
        """Start the ping watchdog greenlet."""
        self.ping_timeout_thread_obj = self.ping_timeout_thread()  # THREAD:1, loop
        # wait((self.ready_commit, self.fail), timeout=self.local_timeout)  # BLOCK, timeout
        # wait((self.commit, self.fail))  # BLOCK

    @g_async
    def ping_timeout_thread(self):
        """Fail the transaction if no ping arrives within ``2 * ping_timeout`` seconds.

        Loops until ``done`` or ``fail`` is set. Each received ping is consumed
        (cleared) so that the next iteration waits for a fresh one.
        """
        while not (self.done.ready() or self.fail.ready()):
            debug_SSE.event({
                "event": "wait_ping",
                "t": datetime.now(),
                "data": None
            })  # DEBUG wait_ping
            ready = wait((self._ping, self.done, self.fail), count=1, timeout=self.ping_timeout * 2)  # BLOCK, timeout
            if not ready:
                # Nothing became ready before the timeout: the client stopped pinging.
                debug_SSE.event({
                    "event": "fail",
                    "t": datetime.now(),
                    "data": "ping timeout"
                })  # DEBUG ping timeout
                self.fail.set()  # EMIT(fail)
                break
            if self._ping.ready():
                debug_SSE.event({
                    "event": "ping",
                    "t": datetime.now(),
                    "data": None
                })  # DEBUG ping
                self._ping.clear()  # EMIT(-ping)
            sleep()  # yield to other greenlets

    def do_work(self, resource):
        """Start the greenlet that delivers ``resource`` as the result after ``result_timeout``."""
        self.result_thread_obj = self.result_thread(resource)  # THREAD:1

    @g_async
    def result_thread(self, resource):
        """After ``result_timeout`` seconds, publish ``resource`` as the result.

        If the transaction already became ready to commit, or failed, in the
        meantime, the late result is silently dropped. Otherwise this sets
        ``result`` and ``ready_commit`` and PUTs the result to ``callback_url``.
        """
        sleep(self.result_timeout)  # BLOCK, sleep
        if self.ready_commit.ready() or self.fail.ready():
            return
        self.result.set(resource)  # EMIT(result)
        self.ready_commit.set()  # EMIT(ready_commit)
        debug_SSE.event({
            "event": "ready_commit",
            "t": datetime.now(),
            "data": None
        })  # DEBUG ready_commit
        data = {"key": self.key, "response": {"data": self.result.get()}}
        requests.put(self.callback_url, headers={"Connection": "close"}, json=data,
                     timeout=self.callback_timeout)

    def ping(self) -> bool:
        """Record a ping from the client.

        :returns: True if the ping was accepted, False if the transaction has
            already failed or finished.
        """
        if not (self.fail.ready() or self.done.ready()):
            self._ping.set()  # EMIT(ping)
            return True
        return False

    @g_async
    def do_commit(self):
        """Commit, wait a jittered delay, then tell ``callback_url`` the transaction is done.

        This is a no-op if the transaction has failed.

        :raises Exception: if ``ready_commit`` or ``result`` is not ready yet.
        """
        if self.fail.ready():
            return
        if not (self.ready_commit.ready() and self.result.ready()):
            raise Exception("Error during commit")
        self.commit.set()  # EMIT(commit)
        debug_SSE.event({
            "event": "commit",
            "t": datetime.now(),
            "data": None
        })  # DEBUG commit
        # Jitter of +-2 whole seconds around ping_timeout. randint() only
        # accepts integers (a float such as 2.5 raises; on Python 3.12+ even
        # 5.0 raises), and ping_timeout is documented as a float. Truncate it
        # and keep the lower bound non-negative.
        base = int(self.ping_timeout)
        sleep(randint(max(0, base - 2), base + 2))
        data = {"key": self.key, "done": True}
        # Bounded timeout: without one, a stalled endpoint hangs this greenlet forever.
        requests.put(self.callback_url, headers={"Connection": "close"}, json=data,
                     timeout=self.callback_timeout)
        debug_SSE.event({
            "event": "done",
            "t": datetime.now(),
            "data": None
        })  # DEBUG done

    @g_async
    def do_rollback(self):
        """Mark the transaction as failed."""
        self.fail.set()  # EMIT(fail)
        debug_SSE.event({
            "event": "rollback",
            "t": datetime.now(),
            "data": None
        })  # DEBUG rollback
class Process(object):
    """Run a command attached to a fresh pseudo-terminal, bridged to a task.

    ``run(task)`` forwards everything the child writes to the pty into
    ``task.output`` as ``{"data": bytes}`` messages. A writer greenlet
    forwards ``task.input`` messages (``{"data": bytes}`` and
    ``{"resize": {"width": w, "height": h}}``) to the pty.
    """
    # TODO: handle bot stdout and stderr
    # TODO: refactor into TTY, Process and TTYProcess?

    def __init__(self, args, env=None, executable=None, shell=False):
        master, slave = pty.openpty()
        self._finished = Event()
        self._master = master
        self._args = args
        # fds still owned by this object; closed exactly once in stop().
        # The parent keeps its copy of the slave end open while the process
        # runs and releases it in stop(), so it no longer leaks.
        self._open_fds = (master, slave)
        try:
            fcntl.fcntl(master, fcntl.F_SETFL, os.O_NONBLOCK)
            # Event-loop watchers on the master fd: 1 = readable, 2 = writable.
            self._read_event = get_hub().loop.io(master, 1)
            self._write_event = get_hub().loop.io(master, 2)
            self._proc = Popen(args, env=env, executable=executable, shell=shell,
                               stdin=slave, stdout=slave, stderr=slave,
                               bufsize=0, universal_newlines=False, close_fds=True)
        except BaseException:
            # Don't leak both pty ends if the child could not be started.
            self._close_fds()
            raise

    def __repr__(self):
        return "Process:%x %r" % (id(self), self._args)

    @property
    def finished(self):
        """True once stop() has completed."""
        return self._finished.ready()

    def _waitclosed(self):
        """Wait for the child to exit, then tear everything down."""
        self._proc.wait()
        self.stop()

    def _writer(self, inch):
        """Forward messages from ``inch`` to the pty.

        This greenlet blocks until messages are ready to be written to the
        pty. ``resize`` messages resize the terminal. ``data`` payloads are
        written completely, looping over partial ``os.write`` results.
        """
        try:
            fd = self._master
            for msg in inch.watch():
                if 'resize' in msg:
                    set_winsize(fd, msg['resize']['width'], msg['resize']['height'])
                if 'data' in msg:
                    buf = msg['data']
                    while buf and not self.finished:
                        try:
                            wait(self._write_event)
                        except Exception:
                            # e.g. stop() cancelled the wait: the pty is going away.
                            break
                        # Write only the bytes not yet written; os.write may be partial.
                        written = os.write(fd, buf)
                        buf = buf[written:]
        except Exception:
            LOG.exception("In Process._writer")

    def run(self, task):
        """Pump pty output into ``task.output`` until EOF, an error, or stop().

        Also starts the writer greenlet (``task.input`` -> pty) and a greenlet
        that calls stop() once the child exits. On the way out it always
        kills the writer and stops the process.
        """
        writer_task = gevent.spawn(self._writer, task.input)
        gevent.spawn(self._waitclosed)
        try:
            fd = self._master
            while not self.finished:
                try:
                    wait(self._read_event)
                except Exception:
                    break
                data = os.read(fd, 1024)
                if not data:
                    break
                # stdout and stderr share the pty (Popen got the slave fd for
                # both), so they cannot be told apart here; proc.stderr is None.
                task.output.send(dict(data=data))
        except Exception:
            LOG.exception("While reading from process")
        finally:
            writer_task.kill()
            self.stop()

    def stop(self):
        """Tear down the pty and make sure the child has exited. Safe to call repeatedly."""
        if self.finished:
            return
        # Stop the watchers before closing the fd they watch; this also wakes
        # any greenlet blocked in wait() on them.
        cancel_wait(self._read_event)
        cancel_wait(self._write_event)
        self._close_fds()
        # poll() returns None while the child is running and its exit code
        # afterwards. The exit code 0 is falsy, so a truthiness check would
        # try to terminate children that have already exited.
        if self._proc.poll() is None:
            self._proc.terminate()
            self._proc.wait()
        self._finished.set()

    def _close_fds(self):
        """Best-effort close of the pty fds, at most once each.

        stop() can be entered a second time, from run()'s ``finally`` or from
        _waitclosed, while the first caller is blocked in ``proc.wait()``.
        Taking ownership of the fds before closing them prevents a second
        os.close() on a number the OS may already have reused.
        """
        fds, self._open_fds = self._open_fds, ()
        for fd in fds:
            try:
                os.close(fd)
            except OSError:
                pass