class NATTraversalProtocol:
    """Send and receive NAT probe messages over the session transport.

    Incoming messages are dispatched on their protocol port. Decoded probes
    are published through the ``on_probe_request`` and ``on_probe_reply``
    signals.
    """

    PROTOCOL_ID = 0x400

    PORT_PROBE_REQUEST = 1
    PORT_PROBE_REPLY = 2
    PORT_DUMMY = 3

    on_probe_request = signal.Signal()
    on_probe_reply = signal.Signal()

    def __init__(self, session):
        """Bind to ``session.transport`` and build the port dispatch table."""
        self.session = session
        self.transport = session.transport

        self.handlers = {
            self.PORT_PROBE_REQUEST: self.handle_probe_request,
            self.PORT_PROBE_REPLY: self.handle_probe_reply,
        }

    def handle(self, station, message):
        """Dispatch an incoming message to the handler for its port.

        Ports without a registered handler (PORT_DUMMY has none) are logged
        and dropped instead of raising KeyError.
        """
        handler = self.handlers.get(message.protocol_port)
        if handler is None:
            logger.warning(
                "Unknown NATTraversalProtocol port: %i", message.protocol_port
            )
            return
        handler(station, message.payload)

    def handle_probe_request(self, station, message):
        """Decode a probe request payload and emit ``on_probe_request``."""
        probe = NATProbeData.deserialize(message)
        self.on_probe_request(station, probe)

    def handle_probe_reply(self, station, message):
        """Decode a probe reply payload and emit ``on_probe_reply``."""
        probe = NATProbeData.deserialize(message)
        self.on_probe_reply(station, probe)

    def send_probe_request(self, station, count=1):
        """Send ``count`` probe requests to ``station``."""
        logger.info("Sending NAT probe to %s", station.address)
        self.send_probe(
            station, NATProbeData.REQUEST, self.PORT_PROBE_REQUEST, count
        )

    def send_probe_reply(self, station, count=1):
        """Send ``count`` probe replies to ``station``."""
        logger.info("Sending NAT probe reply to %s", station.address)
        self.send_probe(
            station, NATProbeData.REPLY, self.PORT_PROBE_REPLY, count
        )

    def send_probe(self, station, probe_type, protocol_port, count=1):
        """Build and send ``count`` probes of ``probe_type`` on ``protocol_port``.

        A fresh probe (with the current wall-clock second) is built for every
        transmission.
        """
        for _ in range(count):
            probe = NATProbeData(
                self.session.rvcid, probe_type, int(time.time())
            )

            message = PIAMessage()
            message.flags = 8
            message.protocol_id = self.PROTOCOL_ID
            message.protocol_port = protocol_port
            message.payload = probe.serialize()
            self.transport.send(station, message)
def __init__(self, client, settings, sock=None):
    """Configure the PRUDP endpoint from ``settings``.

    Reads the transport type, resend parameters and substream count from
    ``settings``, creates a socket matching the transport type unless a
    usable ``sock`` is supplied, and selects the packet encoder that matches
    the transport type (and, for UDP, the configured PRUDP version).
    """
    self.settings = settings

    self.transport_type = settings.get("prudp.transport")
    self.resend_timeout = settings.get("prudp.resend_timeout")
    self.resend_limit = settings.get("prudp.resend_limit")
    self.substreams = settings.get("prudp.substreams")

    self.failure = signal.Signal()

    # Only build a socket when the caller did not hand one in.
    self.sock = sock
    if not self.sock:
        kind = self.transport_type
        if kind == settings.TRANSPORT_UDP:
            self.sock = socket.Socket(socket.TYPE_UDP)
        elif kind == settings.TRANSPORT_TCP:
            self.sock = socket.Socket(socket.TYPE_TCP)
        else:
            self.sock = websocket.WebSocketClient(True)

    # Non-UDP transports always use the lite message format.
    if self.transport_type != settings.TRANSPORT_UDP:
        self.packet_encoder = PRUDPLiteMessage(client, settings)
    elif settings.get("prudp.version") == 0:
        self.packet_encoder = PRUDPMessageV0(client, settings)
    else:
        self.packet_encoder = PRUDPMessageV1(client, settings)

    self.sequence_mgr = SequenceMgr(settings)
    self.message_encoder = PacketEncoder(settings)

    self.ack_events = {}
    self.ack_packets = {}
    self.socket_event = None
    self.packets = []
class MessageTransport:
    """Wrap a ``PacketTransport`` and exchange individual PIA messages.

    Each received packet is split into its messages, which are emitted one
    by one through ``packet_received``.
    """

    packet_received = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = PacketTransport(session)

    def start(self, address):
        """Start the packet transport and register it with the scheduler."""
        self.transport.start(address)
        scheduler.add_socket(self.handle_recv, self.transport)

    def handle_recv(self, pair):
        """Emit ``packet_received`` for every message of a (station, packet) pair."""
        sender, packet = pair
        for entry in packet.messages:
            self.packet_received(sender, entry)

    def send(self, station, message, add_mask=False):
        """Fill in routing fields on ``message`` and send it in its own packet.

        The destination mask is the bit for the station's index, or 0 when
        the index is 0xFD. ``add_mask`` is accepted but not used here.
        """
        if station.index == 0xFD:
            message.destination = 0
        else:
            message.destination = 1 << station.index

        message.station_key = self.session.rvcid
        message.station_index = self.session.station.index

        self.transport.send(station, PIAPacket([message]))

    def size_limit(self):
        """Return the size limit reported by the packet transport."""
        return self.transport.size_limit()
class UnreliableProtocol:
    """Exchange raw application data on port 1 without delivery guarantees.

    Received payloads are emitted through ``message_received``.
    """

    PROTOCOL_ID = 0x2000

    message_received = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = session.transport

    def send(self, station, data):
        """Wrap ``data`` in a PIA message and send it to ``station``."""
        logger.debug("Sending %i bytes of unreliable data", len(data))

        outgoing = PIAMessage()
        outgoing.flags = 0
        outgoing.protocol_id = self.PROTOCOL_ID
        outgoing.protocol_port = 1
        outgoing.payload = data

        self.transport.send(station, outgoing)

    def handle(self, station, message):
        """Emit ``message_received`` for port 1; warn about any other port."""
        if message.protocol_port != 1:
            logger.warning("Unknown UnreliableProtocol port: %i", message.protocol_port)
            return

        logger.debug("Received %i bytes of unreliable data", len(message.payload))
        self.message_received(station, message.payload)
class JoinResponseDecoder:
    """Reassemble a join response that arrives as several fragments.

    Header bytes read from every fragment: [1] station count, [2] host
    index, [3] assigned index, [4] fragment count, [5] fragment index,
    [6] number of infos in the fragment, [7] index of its first info.
    Station infos start at byte 8. Once every info slot is filled,
    ``finished`` is emitted and the decoder resets.
    """

    finished = signal.Signal()

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget the station whose response is being assembled."""
        self.station = None

    def parse(self, station, message):
        """Merge one fragment into the response currently being assembled."""
        start_over = self.station is None
        if not start_over and not self.check_info(station, message):
            logger.warning("Incompatible join response fragment received")
            self.reset()
            start_over = True
        if start_over:
            self.update_info(station, message)

        fragment_index = message[5]
        if self.fragments_received[fragment_index]:
            # Duplicate fragment: nothing new to merge.
            return
        self.fragments_received[fragment_index] = True

        count = message[6]
        first = message[7]

        cursor = 8
        for slot in range(first, first + count):
            if self.infos[slot]:
                logger.warning(
                    "Overlapping join response fragments received")
            self.infos[slot] = StationInfo.deserialize(message[cursor:])
            cursor += StationInfo.sizeof()

        if all(self.infos):
            self.finished(station, self.host_index,
                          self.assigned_index, self.infos)
            self.reset()

    def update_info(self, station, message):
        """Start a new response using the header fields of ``message``."""
        self.station = station

        self.num_stations = message[1]
        self.host_index = message[2]
        self.assigned_index = message[3]
        self.num_fragments = message[4]

        self.fragments_received = [False] * self.num_fragments
        self.infos = [None] * self.num_stations

    def check_info(self, station, message):
        """Return whether ``message`` belongs to the response in progress."""
        if not self.station == station:
            return False
        expected = (
            self.num_stations,
            self.host_index,
            self.assigned_index,
            self.num_fragments,
        )
        # Header bytes 1..4 must match the values recorded by update_info.
        return all(
            value == message[position]
            for position, value in enumerate(expected, start=1)
        )
class KeepAliveProtocol:
    """Send empty keep-alive messages and report any that are received."""

    PROTOCOL_ID = 0xC0

    on_receive = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = session.transport

    def send(self, station):
        """Send an empty-payload keep-alive message to ``station``."""
        logger.debug("Sending keep alive packet")

        keep_alive = PIAMessage()
        keep_alive.flags = 0
        keep_alive.protocol_id = self.PROTOCOL_ID
        keep_alive.protocol_port = 0
        keep_alive.payload = b""

        self.transport.send(station, keep_alive)

    def handle(self, station, message):
        """Emit ``on_receive`` for the sending station; the message is unused."""
        self.on_receive(station)
class NATTraversalHandler:
    """Holder for the ``initiate_probe`` signal."""

    initiate_probe = signal.Signal()
class NATTraversalMgr:
    """Coordinate NAT traversal between the backend and the probe protocol.

    Emits ``nat_traversal_finished`` with the station once a probe reply is
    received, or immediately when a traversal for the same RVCID completed
    less than 30 seconds ago.
    """

    nat_traversal_finished = signal.Signal()

    def __init__(self, session):
        self.backend = session.backend

        nat_protocol = session.nat_protocol
        self.protocol = nat_protocol
        nat_protocol.on_probe_request.add(self.handle_probe_request)
        nat_protocol.on_probe_reply.add(self.handle_probe_reply)

        nat_server = session.backend.nat_traversal_server
        nat_server.handler.initiate_probe.add(self.handle_initiate_probe)
        self.client = NATTraversalClient(session.backend)

        self.station_mgr = session.station_mgr

        # Maps RVCID -> time.monotonic() of the last completed traversal.
        self.past_traversals = {}

    def init_station(self, url):
        """Return the station for ``url``, creating it or updating its address."""
        known = self.station_mgr.find_by_rvcid(url["RVCID"])
        if not known:
            return self.station_mgr.create(url.get_address(), url["RVCID"])
        known.address = url.get_address()
        return known

    def handle_probe_request(self, station, probe):
        """Answer an incoming probe request with a probe reply."""
        logger.info("Received probe request (%i, %i)",
                    probe.connection_id, probe.system_time)
        self.protocol.send_probe_reply(station)

    def handle_probe_reply(self, station, probe):
        """Record the completed traversal and emit ``nat_traversal_finished``."""
        logger.info("Received probe reply: (%i, %i)",
                    probe.connection_id, probe.system_time)
        self.past_traversals[station.rvcid] = time.monotonic()
        self.nat_traversal_finished(station)

    def handle_initiate_probe(self, source):
        """React to a probe initiation request by probing ``source``.

        When ``source["probeinit"]`` is 1, a probe initiation request is sent
        back first.
        """
        logger.info("Received probe initiation request for %s" %source)
        if source["probeinit"] == 1:
            self.request_probe_initiation(source)

        peer = self.init_station(source)
        self.protocol.send_probe_request(peer, 3)

    def request_probe_initiation(self, target):
        """Ask the backend to make ``target`` initiate a probe towards us.

        The local station URL is used when ``target["type"]`` is 0, the
        public one otherwise. The copy's ``probeinit`` flag is set to the
        opposite of the target's.
        """
        logger.info("Sending probe initiation request to %s" %target)

        if target["type"] == 0:
            origin = self.backend.local_station
        else:
            origin = self.backend.public_station
        source = origin.copy()
        source["probeinit"] = 0 if target["probeinit"] == 1 else 1

        self.init_station(target)
        self.client.request_probe_initiation_ext([target], source)

    def report_nat_properties(self, props):
        """Forward the NAT mapping, filtering and RTT to the backend client."""
        logger.info("Reporting NAT properties")
        self.client.report_nat_properties(
            props.nat_mapping, props.nat_filtering, props.rtt
        )

    def start_nat_traversal(self, url):
        """Start a NAT traversal for ``url`` unless one finished recently."""
        rvcid = url["RVCID"]

        finished_at = self.past_traversals.get(rvcid)
        if finished_at is not None and time.monotonic() - finished_at < 30:
            # Traversed within the last 30 seconds: report success right away.
            self.nat_traversal_finished(self.station_mgr.find_by_rvcid(rvcid))
            return

        logger.info("Starting NAT traversal for %s" %url)

        target = url.copy()
        target["probeinit"] = 0
        self.request_probe_initiation(target)
class NintendoNotificationHandler:
    """Holder for the ``process_notification_event`` signal."""

    process_notification_event = signal.Signal()
class MeshMgr:
    """Track mesh membership and react to mesh protocol events.

    Signals: ``join_succeeded``, ``join_denied``, ``station_joined`` and
    ``mesh_destroyed``.
    """

    join_succeeded = signal.Signal()
    join_denied = signal.Signal()
    station_joined = signal.Signal()
    mesh_destroyed = signal.Signal()

    def __init__(self, session):
        self.session = session

        mesh_protocol = session.mesh_protocol
        self.protocol = mesh_protocol
        mesh_protocol.on_join_request.add(self.handle_join_request)
        mesh_protocol.on_join_response.add(self.handle_join_response)
        mesh_protocol.on_join_denied.add(self.handle_join_denied)
        mesh_protocol.on_destroy_request.add(self.handle_destroy_request)
        mesh_protocol.on_destroy_response.add(self.handle_destroy_response)

        self.station_mgr = session.station_mgr
        self.station_mgr.station_connected.add(self.handle_station_connected)

        self.stations = StationList()
        self.host_index = None
        self.update_counter = -1

        self.expecting_join_response = False
        # Maps RVCID -> mesh index for stations announced by the host that
        # have not connected yet.
        self.pending_connect = {}

    def is_host(self):
        """Return whether the local station's index is the host index."""
        return self.session.station.index == self.host_index

    def handle_join_request(self, station, station_index, station_addr):
        """Accept or deny a join request (deny reason 1: not host, 2: bad address)."""
        if not self.is_host():
            logger.warning("Received join request even though we aren't host")
            self.protocol.send_deny_join(station, 1)
            return

        claimed = self.station_mgr.find_by_address(station_addr.address)
        if station != claimed:
            logger.warning(
                "Received join request with unexpected station address")
            self.protocol.send_deny_join(station, 2)
            return

        self.send_join_response(station)
        self.send_update_mesh()
        self.station_joined(station)

    def handle_join_response(self, station, host_index, my_index, infos):
        """Adopt the mesh layout from a join response we were waiting for."""
        if not self.expecting_join_response:
            logger.warning("Unexpected join response received")
            return

        self.expecting_join_response = False
        self.host_index = host_index
        self.stations.add(self.session.station, my_index)

        for entry in infos:
            key = entry.connection_info.public_station.rvcid
            self.pending_connect[key] = entry.index

        self.join_succeeded(infos)

    def handle_station_connected(self, station):
        """Place a newly connected station at the index announced for it."""
        if station.rvcid not in self.pending_connect:
            return

        slot = self.pending_connect.pop(station.rvcid)
        if not self.stations.is_usable(slot):
            logger.warning("Tried to assign station to occupied index")
            return

        self.stations.add(station, slot)
        self.protocol.assign_sliding_window(station)
        self.station_joined(station)

    def handle_join_denied(self, station, reason):
        """Log the denial reason and emit ``join_denied``."""
        logger.info("Join denied (%i)" % reason)
        self.join_denied(station)

    def handle_destroy_request(self, station, station_index, station_address):
        """Acknowledge a destroy request and emit ``mesh_destroyed``."""
        self.protocol.send_destroy_response(
            station, self.session.station.index)
        self.mesh_destroyed()

    def handle_destroy_response(self, station, station_index):
        pass  # TODO: Implement this later

    def send_join_response(self, station):
        """Send the current station list to ``station`` and add it to the mesh."""
        assigned = self.stations.next_index()
        self.protocol.send_join_response(
            station, assigned, self.host_index, self.stations)
        self.stations.add(station)
        self.protocol.assign_sliding_window(station)

    def send_update_mesh(self):
        """Broadcast the station list with an incremented update counter."""
        self.update_counter += 1
        self.protocol.send_update_mesh(
            self.update_counter, self.host_index, self.stations)

    def create(self):
        """Create a mesh with the local station as its host."""
        self.stations.add(self.session.station)
        self.host_index = self.session.station.index

    def join(self, host_station):
        """Send a join request to ``host_station`` and await its response."""
        self.expecting_join_response = True
        self.protocol.send_join_request(host_station)
class MeshProtocol:
    """Encode, decode and dispatch mesh management messages.

    Messages arrive either on the unreliable port (dispatched directly) or on
    the reliable port (routed through the per-station sliding window, which
    hands the reassembled payload back to ``handle_message``).
    """

    PROTOCOL_ID = 0x200

    PORT_UNRELIABLE = 0
    PORT_RELIABLE = 1

    MESSAGE_JOIN_REQUEST = 0x1
    MESSAGE_JOIN_RESPONSE = 0x2
    MESSAGE_LEAVE_REQUEST = 0x4
    MESSAGE_LEAVE_RESPONSE = 0x8
    MESSAGE_DESTROY_MESH = 0x10
    MESSAGE_DESTROY_RESPONSE = 0x11
    MESSAGE_UPDATE_MESH = 0x20
    MESSAGE_KICKOUT_NOTICE = 0x21
    MESSAGE_DUMMY = 0x22
    MESSAGE_CONNECTION_FAILURE = 0x24
    MESSAGE_GREETING = 0x40
    MESSAGE_MIGRATION_FINISH = 0x41
    MESSAGE_GREETING_RESPONSE = 0x42
    MESSAGE_MIGRATION_START = 0x44
    MESSAGE_MIGRATION_RESPONSE = 0x48
    MESSAGE_MULTI_MIGRATION_START = 0x49
    MESSAGE_MULTI_MIGRATION_RANK_DECISION = 0x4A
    MESSAGE_CONNECTION_REPORT = 0x80
    MESSAGE_RELAY_ROUTE_DIRECTIONS = 0x81

    on_join_request = signal.Signal()
    on_join_response = signal.Signal()
    on_join_denied = signal.Signal()
    on_leave_request = signal.Signal()
    on_leave_response = signal.Signal()
    on_destroy_request = signal.Signal()
    on_destroy_response = signal.Signal()

    def __init__(self, session):
        """Bind to the session's transports and build the dispatch table."""
        self.session = session
        self.transport = session.transport
        self.resender = session.resending_transport
        self.station_protocol = session.station_protocol

        self.handlers = {
            self.MESSAGE_JOIN_REQUEST: self.handle_join_request,
            self.MESSAGE_JOIN_RESPONSE: self.handle_join_response,
            self.MESSAGE_LEAVE_REQUEST: self.handle_leave_request,
            self.MESSAGE_LEAVE_RESPONSE: self.handle_leave_response,
            self.MESSAGE_DESTROY_MESH: self.handle_destroy_mesh,
            self.MESSAGE_DESTROY_RESPONSE: self.handle_destroy_response,
            self.MESSAGE_UPDATE_MESH: self.handle_update_mesh,
        }

        # One optional reliable transport per station index.
        self.sliding_windows = [None] * 32

        # Completed (defragmented) join responses are re-emitted through
        # on_join_response.
        self.join_response_decoder = JoinResponseDecoder()
        self.join_response_decoder.finished.add(self.on_join_response)

    def assign_sliding_window(self, station):
        """Create the reliable transport used for ``station``'s index."""
        self.sliding_windows[station.index] = ReliableTransport(
            self.transport, station, self.PROTOCOL_ID,
            self.PORT_RELIABLE, self.handle_message
        )

    def handle(self, station, message):
        """Route an incoming PIA message according to its protocol port."""
        port = message.protocol_port
        if port == self.PORT_UNRELIABLE:
            self.handle_message(station, message.payload)
        elif port == self.PORT_RELIABLE:
            if station.index == 0xFD:
                logger.warning(
                    "Received reliable mesh packet from unknown station")
                return
            transport = self.sliding_windows[station.index]
            if transport is None:
                # No window was assigned for this index; calling .handle on
                # None would raise AttributeError.
                logger.warning(
                    "Received reliable mesh packet without sliding window")
                return
            transport.handle(message)
        else:
            # The original referenced an undefined name `packet` here, which
            # raised NameError instead of logging.
            logger.warning("Unknown MeshProtocol port: %i", port)

    def handle_message(self, station, message):
        """Dispatch a mesh payload on its first byte (the message type).

        Empty payloads and message types without a handler are logged and
        dropped rather than raising IndexError / KeyError.
        """
        if not message:
            logger.warning("Received empty mesh message")
            return

        message_type = message[0]
        handler = self.handlers.get(message_type)
        if handler is None:
            logger.warning(
                "Unknown MeshProtocol message type: %i", message_type)
            return
        handler(station, message)

    def handle_join_request(self, station, message):
        """Acknowledge a join request and emit ``on_join_request``."""
        logger.info("Received join request")
        station_address = StationAddress.deserialize(message[4:])
        station_index = message[1]
        self.station_protocol.send_ack(station, message)
        self.on_join_request(station, station_index, station_address)

    def handle_join_response(self, station, message):
        """Handle a join response fragment or a denial.

        A zero in byte 1 marks a denial whose reason is byte 4; anything else
        is acknowledged and fed to the fragment decoder.
        """
        logger.info("Received join response")
        if message[1] == 0:
            self.on_join_denied(station, message[4])
        else:
            self.station_protocol.send_ack(station, message)
            self.join_response_decoder.parse(station, message)

    def handle_leave_request(self, station, message):
        logger.warning("TODO: Handle leave request")

    def handle_leave_response(self, station, message):
        logger.warning("TODO: Handle leave response")

    def handle_destroy_mesh(self, station, message):
        """Decode a destroy request and emit ``on_destroy_request``."""
        logger.info("Received destroy request")
        station_address = StationAddress.deserialize(message[4:])
        station_index = message[1]
        self.on_destroy_request(station, station_index, station_address)

    def handle_destroy_response(self, station, message):
        """Decode a destroy response and emit ``on_destroy_response``."""
        logger.info("Received destroy response")
        station_index = message[1]
        self.on_destroy_response(station, station_index)

    def handle_update_mesh(self, station, message):
        logger.warning("TODO: Handle mesh update")

    def send_join_request(self, station):
        """Send a join request carrying the local index and station address."""
        logger.info("Sending join request")
        data = bytes(
            [self.MESSAGE_JOIN_REQUEST, self.session.station.index, 0, 0])
        data += self.session.station.station_address().serialize()
        self.send(station, data, 0, True)

    def send_join_response(self, station, assigned_index, host_index, stations):
        """Send the station list to ``station``, split into fragments.

        Each fragment has an 8-byte header (type, station count, host index,
        assigned index, fragment count, fragment index, infos in this
        fragment, index of the first info) followed by the serialized
        connection info, index and a zero byte of every station it covers.
        """
        logger.info("Sending join response")

        infosize = (StationConnectionInfo.sizeof() + 4) & ~3
        limit = self.transport.size_limit() - 0xC
        per_packet = limit // infosize

        # Ceiling division: number of fragments needed for all stations.
        fragments = (len(stations) + per_packet - 1) // per_packet
        for i in range(fragments):
            offset = i * per_packet
            remaining = len(stations) - offset
            num_infos = min(remaining, per_packet)

            data = bytes([
                self.MESSAGE_JOIN_RESPONSE, len(stations), host_index,
                assigned_index, fragments, i, num_infos, offset
            ])

            for j in range(num_infos):
                station_info = stations[offset + j]
                data += station_info.connection_info.serialize()
                data += bytes([station_info.index, 0])

            self.send(station, data, 0, True)

    def send_deny_join(self, station, reason):
        """Send a join denial with ``reason``, once with flags 0 and once with 8."""
        logger.info("Denying join request")
        data = bytes([self.MESSAGE_JOIN_RESPONSE, 0, 0xFF, 0xFF, reason])
        self.send(station, data, 0)
        self.send(station, data, 8)

    def send_destroy_response(self, station, station_index):
        """Send a destroy response, once with flags 0 and once with 8."""
        logger.info("Sending destroy response")
        data = bytes([self.MESSAGE_DESTROY_RESPONSE, station_index])
        self.send(station, data, 0)
        self.send(station, data, 8)

    def send_update_mesh(self, counter, host_index, stations):
        """Send the station list through every assigned sliding window."""
        logger.info("Sending mesh update")
        data = struct.pack(
            ">BBBBIBBBB", self.MESSAGE_UPDATE_MESH, len(stations),
            host_index, 0, counter, 1, 0, host_index, 0
        )
        for station in stations:
            data += station.connection_info.serialize()
            data += bytes([station.index, 0])

        for reliable_transport in filter(None, self.sliding_windows):
            reliable_transport.send(data)

    def send(self, station, payload, flags, ack=False):
        """Send ``payload`` on the unreliable port.

        With ``ack`` set, the message goes through the resending transport;
        otherwise it is sent once through the plain transport.
        """
        message = PIAMessage()
        message.flags = flags
        message.protocol_id = self.PROTOCOL_ID
        message.protocol_port = self.PORT_UNRELIABLE
        message.payload = payload
        if ack:
            self.resender.send(station, message)
        else:
            self.transport.send(station, message)
class StationMgr:
    """Manage station connections on top of the station protocol.

    Signals: ``station_connected``, ``station_disconnected`` and
    ``connection_denied``.
    """

    station_connected = signal.Signal()
    station_disconnected = signal.Signal()
    connection_denied = signal.Signal()

    def __init__(self, session):
        station_protocol = session.station_protocol
        self.protocol = station_protocol

        subscriptions = (
            (station_protocol.on_connection_request,
             self.handle_connection_request),
            (station_protocol.on_connection_response,
             self.handle_connection_response),
            (station_protocol.on_connection_denied,
             self.handle_connection_denied),
            (station_protocol.on_disconnection_request,
             self.handle_disconnection_request),
            (station_protocol.on_disconnection_response,
             self.handle_disconnection_response),
        )
        for event, callback in subscriptions:
            event.add(callback)

        self.stations = StationTable()
        # Stations we sent a connection request to and still await.
        self.pending_connect = []

    def handle_connection_request(self, station, connection_info, connection_id, is_inverse):
        """Validate and answer an incoming connection request.

        The request is denied (reason 1) when the connection info does not
        resolve to the sending station. Otherwise the station's connection
        data is stored, an inverse request is sent unless this request
        already is one, and a connection response is sent.
        """
        expected = self.stations.find_by_connection_info(connection_info)
        if station != expected:
            logger.warning("Unexpected station connection info found in connection request")
            self.protocol.send_deny_connection(station, 1)
            return

        logger.info("Received connection request")
        station.connection_info = connection_info
        station.connection_id = connection_id

        if not is_inverse:
            self.protocol.send_connection_request(station, True)
        self.protocol.send_connection_response(station)

    def handle_connection_response(self, station, identification_info):
        """Mark a pending station as connected and emit ``station_connected``."""
        if station not in self.pending_connect:
            logger.debug("Unexpected connection response received: %s" %identification_info.name)
            return

        logger.info("Station connected: %s" %identification_info.name)
        self.pending_connect.remove(station)
        station.identification_info = identification_info
        station.is_connected = True
        self.station_connected(station)

    def handle_connection_denied(self, station, reason):
        """Drop a pending station and emit ``connection_denied``."""
        if station not in self.pending_connect:
            logger.warning("Unexpected denying connection response received")
            return

        logger.info("Received denying connection response (reason=%i)" %reason)
        self.pending_connect.remove(station)
        self.connection_denied(station)

    def handle_disconnection_request(self, station):
        """Answer a disconnection request and mark the station disconnected."""
        logger.info("Received disconnection request")
        self.protocol.send_disconnection_response(station)
        if station.is_connected:
            station.is_connected = False
            self.station_disconnected(station)

    def handle_disconnection_response(self, station):
        """Finish a disconnect for a station that is still marked connected."""
        if not station.is_connected:
            logger.warning("Unexpected disconnection response received")
            return

        logger.info("Received disconnection response")
        station.is_connected = False
        self.station_disconnected(station)

    def connect(self, station):
        """Connect to ``station``, or report success if already connected."""
        if station.is_connected:
            self.station_connected(station)
            return
        self.pending_connect.append(station)
        self.protocol.send_connection_request(station)

    def cancel_connection(self, station):
        """Stop waiting for a connection response from ``station``."""
        if station in self.pending_connect:
            self.pending_connect.remove(station)

    def disconnect(self, station):
        """Request a disconnect, or report it at once if not connected."""
        if station.is_connected:
            self.protocol.send_disconnection_request(station)
        else:
            self.station_disconnected(station)

    def create(self, address, rvcid):
        """Create a station in the station table."""
        return self.stations.create(address, rvcid)

    def find_by_address(self, address):
        """Look up a station by address."""
        return self.stations.find_by_address(address)

    def find_by_connection_info(self, info):
        """Look up a station by connection info."""
        return self.stations.find_by_connection_info(info)

    def find_by_rvcid(self, rvcid):
        """Look up a station by RVCID."""
        return self.stations.find_by_rvcid(rvcid)
class StationProtocol:
    """Encode, decode and dispatch station connection messages.

    Only the unreliable port is handled. The first payload byte selects the
    message type.
    """

    PROTOCOL_ID = 0x100

    PORT_UNRELIABLE = 0
    PORT_RELIABLE = 1

    MESSAGE_CONNECTION_REQUEST = 1
    MESSAGE_CONNECTION_RESPONSE = 2
    MESSAGE_DISCONNECTION_REQUEST = 3
    MESSAGE_DISCONNECTION_RESPONSE = 4
    MESSAGE_ACK = 5
    MESSAGE_RELAY_CONNECTION_REQUEST = 6
    MESSAGE_RELAY_CONNECTION_RESPONSE = 7

    on_connection_request = signal.Signal()
    on_connection_response = signal.Signal()
    on_connection_denied = signal.Signal()
    on_disconnection_request = signal.Signal()
    on_disconnection_response = signal.Signal()

    def __init__(self, session):
        """Bind to the session's transports and build the dispatch table."""
        self.session = session
        self.transport = session.transport
        self.resender = session.resending_transport

        self.handlers = {
            self.MESSAGE_CONNECTION_REQUEST: self.handle_connection_request,
            self.MESSAGE_CONNECTION_RESPONSE: self.handle_connection_response,
            self.MESSAGE_DISCONNECTION_REQUEST: self.handle_disconnection_request,
            self.MESSAGE_DISCONNECTION_RESPONSE: self.handle_disconnection_response,
            self.MESSAGE_ACK: self.handle_ack,
        }

        self.inverse_requests = {}
        self.connection_responses = {}

    def handle(self, station, message):
        """Dispatch an incoming message on the first byte of its payload.

        Empty payloads and message types without a handler (the relay types
        have none) are logged and dropped instead of raising IndexError /
        KeyError.
        """
        if message.protocol_port != self.PORT_UNRELIABLE:
            logger.warning("Only unreliable station protocol is supported")
            return

        payload = message.payload
        if not payload:
            logger.warning("Received empty station protocol message")
            return

        message_type = payload[0]
        handler = self.handlers.get(message_type)
        if handler is None:
            logger.warning(
                "Unknown StationProtocol message type: %i", message_type)
            return
        handler(station, payload)

    def handle_connection_request(self, station, message):
        """Decode a connection request; deny it (reason 2) unless byte 2 is 3."""
        if message[2] != 3:
            logger.warning("Unsupported version number found in connection request")
            self.send_deny_connection(station, 2)
            return

        connection_info = StationConnectionInfo.deserialize(message[4:])
        connection_id = message[1]
        is_inverse = message[3]
        self.send_ack(station, message)
        self.on_connection_request(station, connection_info, connection_id, is_inverse)

    def handle_connection_response(self, station, message):
        """Emit a denial when byte 1 is non-zero, otherwise a response."""
        if message[1]:
            self.on_connection_denied(station, message[1])
        else:
            identification_info = IdentificationInfo.deserialize(message[4:])
            self.send_ack(station, message)
            self.on_connection_response(station, identification_info)

    def handle_disconnection_request(self, station, message):
        """Emit ``on_disconnection_request`` for the sending station."""
        self.on_disconnection_request(station)

    def handle_disconnection_response(self, station, message):
        """Emit ``on_disconnection_response`` for the sending station."""
        self.on_disconnection_response(station)

    def handle_ack(self, station, message):
        """Forward an acknowledgement to the resending transport."""
        self.resender.handle_ack(message)

    def send_connection_request(self, station, is_inverse=False):
        """Send a (possibly inverse) connection request through the resender."""
        if is_inverse:
            logger.info("Sending inverse connection request")
        else:
            logger.info("Sending connection request")

        data = bytes([
            self.MESSAGE_CONNECTION_REQUEST,
            self.session.station.connection_id,
            3,
            is_inverse,
        ])
        data += self.session.station.connection_info.serialize()
        self.send(station, data, True)

    def send_connection_response(self, station):
        """Send our identification info as a connection response."""
        logger.info("Sending connection response")
        data = bytes([self.MESSAGE_CONNECTION_RESPONSE, 0, 3, 3])
        data += self.session.station.identification_info.serialize()
        data += b"\0\0"  # Alignment
        self.send(station, data, True)

    def send_deny_connection(self, station, reason):
        """Send a connection response whose byte 1 carries a denial reason."""
        logger.info("Denying connection request")
        data = bytes([self.MESSAGE_CONNECTION_RESPONSE, reason, 3, 0])
        self.send(station, data)

    def send_disconnection_request(self, station):
        """Send a disconnection request to ``station``."""
        logger.info("Sending disconnection request")
        data = bytes([self.MESSAGE_DISCONNECTION_REQUEST])
        self.send(station, data)

    def send_disconnection_response(self, station):
        """Send a disconnection response to ``station``."""
        logger.info("Sending disconnection response")
        data = bytes([self.MESSAGE_DISCONNECTION_RESPONSE])
        self.send(station, data)

    def send_ack(self, station, message):
        """Acknowledge ``message`` using the big-endian id in its last 4 bytes."""
        ack_id = struct.unpack_from(">I", message, -4)[0]
        logger.info("Acknowledging packet (%i)", ack_id)
        data = struct.pack(">BxxxI", self.MESSAGE_ACK, ack_id)
        self.send(station, data)

    def send(self, station, payload, ack=False):
        """Send ``payload`` on the unreliable port.

        With ``ack`` set, the message goes through the resending transport;
        otherwise it is sent once through the plain transport.
        """
        message = PIAMessage()
        message.flags = 0
        message.protocol_id = self.PROTOCOL_ID
        message.protocol_port = self.PORT_UNRELIABLE
        message.payload = payload
        if ack:
            self.resender.send(station, message)
        else:
            self.transport.send(station, message)
class MeshMgr:
    """Track mesh membership and drive the (blocking) join procedure.

    Signals: ``mesh_created``, ``mesh_destroyed``, ``station_joined`` and
    ``station_left``.
    """

    mesh_created = signal.Signal()
    mesh_destroyed = signal.Signal()
    station_joined = signal.Signal()
    station_left = signal.Signal()

    JOIN_OK = 0
    JOIN_DENIED = 1
    JOIN_WAITING = 2
    JOIN_NONE = 3

    def __init__(self, session):
        """Subscribe to the mesh protocol's signals and reset mesh state."""
        self.session = session

        self.protocol = session.mesh_protocol
        self.protocol.on_join_request.add(self.handle_join_request)
        self.protocol.on_join_response.add(self.handle_join_response)
        self.protocol.on_join_denied.add(self.handle_join_denied)
        self.protocol.on_leave_request.add(self.handle_leave_request)
        self.protocol.on_leave_response.add(self.handle_leave_response)
        self.protocol.on_destroy_request.add(self.handle_destroy_request)
        self.protocol.on_destroy_response.add(self.handle_destroy_response)
        self.protocol.on_mesh_update.add(self.handle_mesh_update)

        self.station_mgr = session.station_mgr
        self.connection_mgr = session.connection_mgr

        self.stations = StationList()
        self.host_index = None
        self.update_counter = -1

        self.join_state = self.JOIN_NONE

        # handle_station_connected reads this mapping (RVCID -> mesh index).
        # It was never created here before, so that method raised
        # AttributeError when called.
        self.pending_connect = {}

    def is_host(self):
        """Return whether the local station's index is the host index."""
        return self.session.station.index == self.host_index

    def handle_join_request(self, station, station_index, station_addr):
        """Accept or deny a join request (deny reason 1: not host, 2: bad address)."""
        if self.is_host():
            if station != self.station_mgr.find_by_address(station_addr.address):
                logger.warning("Received join request with unexpected station address")
                self.protocol.send_deny_join(station, 2)
            else:
                logger.info("Received join request")
                index = self.stations.next_index()
                self.protocol.send_join_response(station, index, self.host_index, self.stations)
                self.stations.add(station)
                self.protocol.assign_sliding_window(station)
                self.send_update_mesh()
                self.station_joined(station)
        else:
            logger.warning("Received join request even though we aren't host")
            self.protocol.send_deny_join(station, 1)

    def handle_join_response(self, station, host_index, my_index, infos):
        """Adopt the mesh layout from a join response we were waiting for.

        ``host_index`` and ``my_index`` are positions in ``infos``; the mesh
        indices are taken from the corresponding entries.
        """
        if self.join_state != self.JOIN_WAITING:
            logger.warning("Unexpected join response received")
        else:
            host_index = infos[host_index].index
            my_index = infos[my_index].index
            logger.info("Received join response: (%i, %i)", host_index, my_index)

            self.join_state = self.JOIN_OK

            self.host_index = host_index
            self.stations.add(self.session.station, my_index)
            self.stations.add(station, host_index)
            self.protocol.assign_sliding_window(station)
            self.mesh_created(host_index, my_index)
            self.station_joined(station)

    def handle_join_denied(self, station, reason):
        """Record that the pending join was denied."""
        logger.info("Join denied (%i)", reason)
        self.join_state = self.JOIN_DENIED

    def handle_leave_request(self, station, station_index, station_address):
        if self.is_host():
            logger.warning("TODO: Handle leave request")
        else:
            logger.warning("Unexpected leave request received")

    def handle_leave_response(self, station, station_index, station_address):
        logger.warning("Unexpected leave response received")

    def handle_destroy_request(self, station, station_index, station_address):
        """Acknowledge a destroy request unless we are the host ourselves."""
        if self.is_host():
            logger.warning("Unexpected destroy request received")
        else:
            self.protocol.send_destroy_response(station, self.session.station.index)
            self.mesh_destroyed()

    def handle_destroy_response(self, station, station_index):
        logger.warning("Unexpected destroy response received")

    def handle_mesh_update(self, infos):
        """Add stations announced in a mesh update and connect to some of them.

        Stations not yet in the mesh are created, added at the announced
        index and given a sliding window. Those whose index is lower than the
        local station's are passed to the connection manager.
        """
        # NOTE(review): stations remaining in disconnect_list after the loop
        # are not acted upon anywhere in this method.
        disconnect_list = list(self.stations)
        connect_list = []
        for info in infos:
            station = self.station_mgr.find_by_connection_info(info.connection_info)
            if station in disconnect_list:
                disconnect_list.remove(station)
                if station.index != info.index:
                    logger.error("Station index changed unexpectedly (%i -> %i)",
                                 station.index, info.index)
            else:
                rvcid = info.connection_info.public_station.rvcid
                station = self.station_mgr.create(None, rvcid)
                self.stations.add(station, info.index)
                self.protocol.assign_sliding_window(station)
                if station.index < self.session.station.index:
                    connect_list.append(info.connection_info)
        self.connection_mgr.connect(*connect_list)

    def handle_station_connected(self, station):
        """Place a connected station at the index recorded in ``pending_connect``."""
        if station.rvcid in self.pending_connect:
            index = self.pending_connect.pop(station.rvcid)
            if self.stations.is_usable(index):
                self.stations.add(station, index)
                self.protocol.assign_sliding_window(station)
                self.station_joined(station)
            else:
                logger.warning("Tried to assign station to occupied index")

    def send_update_mesh(self):
        """Broadcast the station list with an incremented update counter."""
        self.update_counter += 1
        self.protocol.send_update_mesh(
            self.update_counter, self.host_index, self.stations
        )

    def create(self):
        """Create a mesh with the local station as its host."""
        self.stations.add(self.session.station)
        self.host_index = self.session.station.index

    def join(self, host_station):
        """Join the mesh hosted by ``host_station``, blocking until done.

        Pumps ``scheduler.update()`` until the join is answered and then
        until every station in the mesh reports ``is_connected``.

        Raises:
            RuntimeError: if the host denies the join request.
        """
        self.join_state = self.JOIN_WAITING
        self.protocol.send_join_request(host_station)
        while self.join_state == self.JOIN_WAITING:
            scheduler.update()

        if self.join_state == self.JOIN_DENIED:
            raise RuntimeError("Join request denied")

        logger.info("Wait until all stations are connected")
        while True:
            # Evaluate before pumping so the scheduler always runs at least
            # once, exactly as the flag-based loop did.
            all_connected = all(
                station.is_connected for station in self.stations
            )
            scheduler.update()
            if all_connected:
                break

        logger.info("Successfully joined a mesh!")