class NATTraversalProtocol:
    """PIA NAT traversal protocol (protocol id 0x400).

    Sends NAT probe requests/replies to stations through the session
    transport and re-emits received probes on the class-level signals
    ``on_probe_request`` and ``on_probe_reply``.
    """

    PROTOCOL_ID = 0x400

    PORT_PROBE_REQUEST = 1
    PORT_PROBE_REPLY = 2
    PORT_DUMMY = 3  # declared, but no handler is registered for it

    on_probe_request = signal.Signal()
    on_probe_reply = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = session.transport

        # protocol port -> handler(station, payload)
        self.handlers = {
            self.PORT_PROBE_REQUEST: self.handle_probe_request,
            self.PORT_PROBE_REPLY: self.handle_probe_reply,
        }

    def handle(self, station, message):
        """Dispatch an incoming message to the handler for its protocol port.

        Messages on a port without a handler (``PORT_DUMMY`` or anything a
        peer sends) are logged and dropped rather than raising ``KeyError``
        out of the receive path.
        """
        handler = self.handlers.get(message.protocol_port)
        if handler is None:
            logger.warning("Unknown NATTraversalProtocol port: %i",
                           message.protocol_port)
            return
        handler(station, message.payload)

    def handle_probe_request(self, station, message):
        """Deserialize a probe request payload and emit ``on_probe_request``."""
        probe = NATProbeData.deserialize(message)
        self.on_probe_request(station, probe)

    def handle_probe_reply(self, station, message):
        """Deserialize a probe reply payload and emit ``on_probe_reply``."""
        probe = NATProbeData.deserialize(message)
        self.on_probe_reply(station, probe)

    def send_probe_request(self, station, count=1):
        """Send *count* probe requests to *station*."""
        logger.info("Sending NAT probe to %s", station.address)
        self.send_probe(station, NATProbeData.REQUEST, self.PORT_PROBE_REQUEST,
                        count)

    def send_probe_reply(self, station, count=1):
        """Send *count* probe replies to *station*."""
        logger.info("Sending NAT probe reply to %s", station.address)
        self.send_probe(station, NATProbeData.REPLY, self.PORT_PROBE_REPLY,
                        count)

    def send_probe(self, station, probe_type, protocol_port, count=1):
        """Send *count* probes of *probe_type* to *station* on *protocol_port*.

        Each probe carries the session RVCID and the current Unix time in
        whole seconds, sampled separately for every probe.
        """
        for _ in range(count):
            probe = NATProbeData(self.session.rvcid, probe_type,
                                 int(time.time()))
            message = PIAMessage()
            message.flags = 8
            message.protocol_id = self.PROTOCOL_ID
            message.protocol_port = protocol_port
            message.payload = probe.serialize()
            self.transport.send(station, message)
Example #2
0
    def __init__(self, client, settings, sock=None):
        """Configure the PRUDP transport from *settings*.

        Args:
            client: passed through to the packet encoder constructor.
            settings: settings object providing ``get`` and the
                ``TRANSPORT_UDP`` / ``TRANSPORT_TCP`` constants.
            sock: an existing socket to use; when falsy, one is created to
                match ``prudp.transport`` (UDP, TCP, otherwise a WebSocket
                client).
        """
        self.settings = settings

        transport_type = settings.get("prudp.transport")
        self.transport_type = transport_type
        self.resend_timeout = settings.get("prudp.resend_timeout")
        self.resend_limit = settings.get("prudp.resend_limit")
        self.substreams = settings.get("prudp.substreams")

        self.failure = signal.Signal()

        is_udp = transport_type == settings.TRANSPORT_UDP

        self.sock = sock
        if not self.sock:
            if is_udp:
                self.sock = socket.Socket(socket.TYPE_UDP)
            elif transport_type == settings.TRANSPORT_TCP:
                self.sock = socket.Socket(socket.TYPE_TCP)
            else:
                self.sock = websocket.WebSocketClient(True)

        # UDP uses the versioned PRUDP framing; every other transport uses
        # the "lite" framing.
        if not is_udp:
            encoder_cls = PRUDPLiteMessage
        elif settings.get("prudp.version") == 0:
            encoder_cls = PRUDPMessageV0
        else:
            encoder_cls = PRUDPMessageV1
        self.packet_encoder = encoder_cls(client, settings)

        self.sequence_mgr = SequenceMgr(settings)
        self.message_encoder = PacketEncoder(settings)

        self.ack_events = {}
        self.ack_packets = {}
        self.socket_event = None
        self.packets = []
Example #3
0
class MessageTransport:
    """Send and receive individual PIA messages over a PacketTransport.

    Incoming packets are split into their messages, and each message is
    emitted on the class-level ``packet_received`` signal.
    """

    packet_received = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = PacketTransport(session)

    def start(self, address):
        """Start the packet transport on *address* and register it for polling."""
        self.transport.start(address)
        scheduler.add_socket(self.handle_recv, self.transport)

    def handle_recv(self, pair):
        """Emit ``packet_received(station, message)`` for each message in a packet."""
        station, packet = pair
        for msg in packet.messages:
            self.packet_received(station, msg)

    def send(self, station, message, add_mask=False):
        """Stamp routing fields on *message* and send it alone in a packet.

        ``add_mask`` is accepted for interface compatibility but is unused.
        """
        # Station index 0xFD gets an empty destination mask; every other
        # index maps to its own bit.
        message.destination = 0 if station.index == 0xFD else 1 << station.index
        message.station_key = self.session.rvcid
        message.station_index = self.session.station.index

        self.transport.send(station, PIAPacket([message]))

    def size_limit(self):
        """Return the size limit reported by the underlying packet transport."""
        return self.transport.size_limit()
Example #4
0
class UnreliableProtocol:
    """PIA unreliable data protocol (protocol id 0x2000, data on port 1)."""

    PROTOCOL_ID = 0x2000

    message_received = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = session.transport

    def send(self, station, data):
        """Send *data* to *station* as a single unreliable message on port 1."""
        logger.debug("Sending %i bytes of unreliable data", len(data))
        msg = PIAMessage()
        msg.flags = 0
        msg.protocol_id = self.PROTOCOL_ID
        msg.protocol_port = 1
        msg.payload = data
        self.transport.send(station, msg)

    def handle(self, station, message):
        """Emit ``message_received(station, payload)`` for port-1 messages.

        Messages on any other port are logged and ignored.
        """
        port = message.protocol_port
        if port != 1:
            logger.warning("Unknown UnreliableProtocol port: %i", port)
            return

        logger.debug("Received %i bytes of unreliable data",
                     len(message.payload))
        self.message_received(station, message.payload)
Example #5
0
class JoinResponseDecoder:
    """Reassemble a fragmented mesh join response.

    Fragment byte layout, as read by ``parse``/``update_info``:
      [1] number of stations   [2] host index     [3] assigned index
      [4] number of fragments  [5] fragment index [6] infos in fragment
      [7] slot of the first info in this fragment
      [8:] consecutive serialized ``StationInfo`` records

    Once every station slot is filled, the class-level ``finished`` signal is
    emitted with ``(station, host_index, assigned_index, infos)``.
    """

    finished = signal.Signal()

    # Number of fixed header bytes that precede the StationInfo records.
    HEADER_SIZE = 8

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget the response currently being assembled."""
        self.station = None

    def parse(self, station, message):
        """Feed one join response fragment received from *station*.

        A fragment whose header disagrees with the response in progress
        restarts assembly. Duplicate fragments are ignored. Malformed
        fragments are logged and dropped instead of raising ``IndexError``:
        too short for the header, a fragment index outside the announced
        count, or infos that would land outside the announced station count.
        """
        if len(message) < self.HEADER_SIZE:
            logger.warning("Join response fragment is too short")
            return

        if self.station is None:
            self.update_info(station, message)
        elif not self.check_info(station, message):
            logger.warning("Incompatible join response fragment received")
            self.reset()
            self.update_info(station, message)

        fragment_index = message[5]
        if fragment_index >= self.num_fragments:
            logger.warning("Join response fragment index out of range")
            return
        if self.fragments_received[fragment_index]:
            return  # already processed this fragment

        fragment_length = message[6]
        fragment_offs = message[7]
        if fragment_offs + fragment_length > self.num_stations:
            logger.warning("Join response fragment exceeds station count")
            return

        self.fragments_received[fragment_index] = True

        offset = self.HEADER_SIZE
        for index in range(fragment_offs, fragment_offs + fragment_length):
            if self.infos[index]:
                logger.warning(
                    "Overlapping join response fragments received")

            self.infos[index] = StationInfo.deserialize(message[offset:])
            offset += StationInfo.sizeof()

        if all(self.infos):
            self.finished(station, self.host_index, self.assigned_index,
                          self.infos)
            self.reset()

    def update_info(self, station, message):
        """Start assembling a new response described by *message*'s header."""
        self.station = station
        self.num_stations = message[1]
        self.host_index = message[2]
        self.assigned_index = message[3]
        self.num_fragments = message[4]
        self.fragments_received = [False] * self.num_fragments
        self.infos = [None] * self.num_stations

    def check_info(self, station, message):
        """Return True if *message* belongs to the response being assembled."""
        return (self.station == station and
                self.num_stations == message[1] and
                self.host_index == message[2] and
                self.assigned_index == message[3] and
                self.num_fragments == message[4])
Example #6
0
class KeepAliveProtocol:
    """PIA keep-alive protocol (protocol id 0xC0)."""

    PROTOCOL_ID = 0xC0

    on_receive = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = session.transport

    def send(self, station):
        """Send an empty keep-alive message to *station* on port 0."""
        logger.debug("Sending keep alive packet")
        keepalive = PIAMessage()
        keepalive.flags = 0
        keepalive.protocol_id = self.PROTOCOL_ID
        keepalive.protocol_port = 0
        keepalive.payload = b""
        self.transport.send(station, keepalive)

    def handle(self, station, message):
        """Emit ``on_receive(station)``; the message contents are not inspected."""
        self.on_receive(station)
Example #7
0
class NATTraversalHandler:
    """Holds the ``initiate_probe`` signal.

    The signal is a class attribute, so it is shared by every instance.
    """

    initiate_probe = signal.Signal()
Example #8
0
class NATTraversalMgr:
    """Coordinate NAT traversal between the local station and peers.

    Answers incoming probe requests and records when a probe reply was last
    received from each RVCID. It also asks the NAT traversal server to make
    a peer start probing. ``nat_traversal_finished(station)`` is emitted when
    a probe reply arrives, or when a recent traversal is reused.
    """

    nat_traversal_finished = signal.Signal()

    # Seconds during which a completed traversal to the same RVCID is reused
    # instead of probing again.
    TRAVERSAL_CACHE_TIME = 30

    def __init__(self, session):
        self.backend = session.backend

        self.protocol = session.nat_protocol
        self.protocol.on_probe_request.add(self.handle_probe_request)
        self.protocol.on_probe_reply.add(self.handle_probe_reply)

        server = session.backend.nat_traversal_server
        server.handler.initiate_probe.add(self.handle_initiate_probe)
        self.client = NATTraversalClient(session.backend)

        self.station_mgr = session.station_mgr

        # rvcid -> time.monotonic() at which the last probe reply was received
        self.past_traversals = {}

    def init_station(self, url):
        """Return the station for ``url["RVCID"]``, creating it if needed.

        An already-known station has its address refreshed from *url*.
        """
        station = self.station_mgr.find_by_rvcid(url["RVCID"])
        if station:
            station.address = url.get_address()
        else:
            station = self.station_mgr.create(url.get_address(), url["RVCID"])
        return station

    def handle_probe_request(self, station, probe):
        """Reply to an incoming probe request."""
        logger.info("Received probe request (%i, %i)", probe.connection_id,
                    probe.system_time)
        self.protocol.send_probe_reply(station)

    def handle_probe_reply(self, station, probe):
        """Record the traversal time and report that traversal finished."""
        logger.info("Received probe reply: (%i, %i)", probe.connection_id,
                    probe.system_time)
        self.past_traversals[station.rvcid] = time.monotonic()
        self.nat_traversal_finished(station)

    def handle_initiate_probe(self, source):
        """Handle a server request to start probing the station *source*.

        If ``source["probeinit"] == 1``, a probe initiation request is sent
        back first. Then three probe requests go to the station.
        """
        logger.info("Received probe initiation request for %s", source)
        if source["probeinit"] == 1:
            self.request_probe_initiation(source)
        station = self.init_station(source)
        self.protocol.send_probe_request(station, 3)

    def request_probe_initiation(self, target):
        """Ask the server to have *target* start probing us.

        The local station URL is advertised when ``target["type"] == 0``;
        otherwise the public one is. The advertised copy carries the
        opposite ``probeinit`` flag to *target*'s.
        """
        logger.info("Sending probe initiation request to %s", target)
        if target["type"] == 0:
            source = self.backend.local_station
        else:
            source = self.backend.public_station

        source = source.copy()
        source["probeinit"] = 0 if target["probeinit"] == 1 else 1

        self.init_station(target)
        self.client.request_probe_initiation_ext([target], source)

    def report_nat_properties(self, props):
        """Report the NAT mapping, filtering and RTT in *props* to the server."""
        logger.info("Reporting NAT properties")
        self.client.report_nat_properties(
            props.nat_mapping, props.nat_filtering, props.rtt
        )

    def start_nat_traversal(self, url):
        """Begin NAT traversal towards the station described by *url*.

        If a probe reply came from the same RVCID less than
        ``TRAVERSAL_CACHE_TIME`` seconds ago and that station is still
        known, ``nat_traversal_finished`` is emitted immediately. Otherwise
        a fresh traversal is started; a cached RVCID whose station is gone
        no longer produces ``nat_traversal_finished(None)``.
        """
        rvcid = url["RVCID"]
        last_success = self.past_traversals.get(rvcid)
        if (last_success is not None and
                time.monotonic() - last_success < self.TRAVERSAL_CACHE_TIME):
            station = self.station_mgr.find_by_rvcid(rvcid)
            if station:
                self.nat_traversal_finished(station)
                return
            # Station record is gone; fall through and traverse again.

        logger.info("Starting NAT traversal for %s", url)
        target = url.copy()
        target["probeinit"] = 0
        self.request_probe_initiation(target)
Example #9
0
class NintendoNotificationHandler:
    """Holds the ``process_notification_event`` signal.

    The signal is a class attribute, so it is shared by every instance.
    """

    process_notification_event = signal.Signal()
Example #10
0
class MeshMgr:
    """Track mesh membership and react to MeshProtocol join/destroy events."""

    join_succeeded = signal.Signal()
    join_denied = signal.Signal()
    station_joined = signal.Signal()

    mesh_destroyed = signal.Signal()

    def __init__(self, session):
        self.session = session

        proto = session.mesh_protocol
        self.protocol = proto
        proto.on_join_request.add(self.handle_join_request)
        proto.on_join_response.add(self.handle_join_response)
        proto.on_join_denied.add(self.handle_join_denied)
        proto.on_destroy_request.add(self.handle_destroy_request)
        proto.on_destroy_response.add(self.handle_destroy_response)

        self.station_mgr = session.station_mgr
        self.station_mgr.station_connected.add(self.handle_station_connected)

        self.stations = StationList()
        self.host_index = None

        self.update_counter = -1

        self.expecting_join_response = False

        # rvcid -> mesh index, filled from a join response and consumed once
        # the corresponding station connects.
        self.pending_connect = {}

    def is_host(self):
        """Return True if the local station's index equals the host index."""
        return self.session.station.index == self.host_index

    def handle_join_request(self, station, station_index, station_addr):
        """Accept a join request as host, or deny it (reason 1 or 2)."""
        if not self.is_host():
            logger.warning("Received join request even though we aren't host")
            self.protocol.send_deny_join(station, 1)
            return

        expected = self.station_mgr.find_by_address(station_addr.address)
        if station != expected:
            logger.warning(
                "Received join request with unexpected station address")
            self.protocol.send_deny_join(station, 2)
            return

        self.send_join_response(station)
        self.send_update_mesh()
        self.station_joined(station)

    def handle_join_response(self, station, host_index, my_index, infos):
        """Record our mesh index and remember which stations must connect."""
        if not self.expecting_join_response:
            logger.warning("Unexpected join response received")
            return

        self.expecting_join_response = False
        self.host_index = host_index
        self.stations.add(self.session.station, my_index)
        self.pending_connect.update(
            (info.connection_info.public_station.rvcid, info.index)
            for info in infos)
        self.join_succeeded(infos)

    def handle_station_connected(self, station):
        """Place a newly connected pending station at its assigned index."""
        rvcid = station.rvcid
        if rvcid not in self.pending_connect:
            return

        index = self.pending_connect.pop(rvcid)
        if not self.stations.is_usable(index):
            logger.warning("Tried to assign station to occupied index")
            return

        self.stations.add(station, index)
        self.protocol.assign_sliding_window(station)
        self.station_joined(station)

    def handle_join_denied(self, station, reason):
        """Log the deny reason and emit ``join_denied``."""
        logger.info("Join denied (%i)" % reason)
        self.join_denied(station)

    def handle_destroy_request(self, station, station_index, station_address):
        """Acknowledge a destroy request and emit ``mesh_destroyed``."""
        self.protocol.send_destroy_response(station,
                                            self.session.station.index)
        self.mesh_destroyed()

    def handle_destroy_response(self, station, station_index):
        pass  # TODO: Implement this later

    def send_join_response(self, station):
        """Send *station* its join response, then add it to the mesh."""
        next_index = self.stations.next_index()
        self.protocol.send_join_response(station, next_index, self.host_index,
                                         self.stations)
        self.stations.add(station)
        self.protocol.assign_sliding_window(station)

    def send_update_mesh(self):
        """Broadcast the current station list with an incremented counter."""
        self.update_counter += 1
        self.protocol.send_update_mesh(self.update_counter, self.host_index,
                                       self.stations)

    def create(self):
        """Create a new mesh with the local station as host."""
        self.stations.add(self.session.station)
        self.host_index = self.session.station.index

    def join(self, host_station):
        """Send a join request to *host_station* and await the response."""
        self.expecting_join_response = True
        self.protocol.send_join_request(host_station)
Example #11
0
class MeshProtocol:
    """PIA mesh protocol (protocol id 0x200).

    Decodes mesh control messages, which arrive either on the unreliable
    port or through a per-station reliable sliding window, and re-emits them
    on the class-level ``on_*`` signals. Also builds and sends the matching
    outgoing messages.
    """

    PROTOCOL_ID = 0x200

    PORT_UNRELIABLE = 0
    PORT_RELIABLE = 1

    MESSAGE_JOIN_REQUEST = 0x1
    MESSAGE_JOIN_RESPONSE = 0x2
    MESSAGE_LEAVE_REQUEST = 0x4
    MESSAGE_LEAVE_RESPONSE = 0x8
    MESSAGE_DESTROY_MESH = 0x10
    MESSAGE_DESTROY_RESPONSE = 0x11
    MESSAGE_UPDATE_MESH = 0x20
    MESSAGE_KICKOUT_NOTICE = 0x21
    MESSAGE_DUMMY = 0x22
    MESSAGE_CONNECTION_FAILURE = 0x24
    MESSAGE_GREETING = 0x40
    MESSAGE_MIGRATION_FINISH = 0x41
    MESSAGE_GREETING_RESPONSE = 0x42
    MESSAGE_MIGRATION_START = 0x44
    MESSAGE_MIGRATION_RESPONSE = 0x48
    MESSAGE_MULTI_MIGRATION_START = 0x49
    MESSAGE_MULTI_MIGRATION_RANK_DECISION = 0x4A
    MESSAGE_CONNECTION_REPORT = 0x80
    MESSAGE_RELAY_ROUTE_DIRECTIONS = 0x81

    on_join_request = signal.Signal()
    on_join_response = signal.Signal()
    on_join_denied = signal.Signal()

    on_leave_request = signal.Signal()
    on_leave_response = signal.Signal()

    on_destroy_request = signal.Signal()
    on_destroy_response = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = session.transport
        self.resender = session.resending_transport
        self.station_protocol = session.station_protocol

        # Message types that are decoded; handle_message() logs and drops
        # any other type.
        self.handlers = {
            self.MESSAGE_JOIN_REQUEST: self.handle_join_request,
            self.MESSAGE_JOIN_RESPONSE: self.handle_join_response,
            self.MESSAGE_LEAVE_REQUEST: self.handle_leave_request,
            self.MESSAGE_LEAVE_RESPONSE: self.handle_leave_response,
            self.MESSAGE_DESTROY_MESH: self.handle_destroy_mesh,
            self.MESSAGE_DESTROY_RESPONSE: self.handle_destroy_response,
            self.MESSAGE_UPDATE_MESH: self.handle_update_mesh
        }

        # Reliable transport per station index; None until assigned.
        self.sliding_windows = [None] * 32

        self.join_response_decoder = JoinResponseDecoder()
        self.join_response_decoder.finished.add(self.on_join_response)

    def assign_sliding_window(self, station):
        """Create the reliable transport used for mesh traffic with *station*."""
        self.sliding_windows[station.index] = ReliableTransport(
            self.transport, station, self.PROTOCOL_ID, self.PORT_RELIABLE,
            self.handle_message)

    def handle(self, station, message):
        """Route an incoming PIA message by protocol port.

        Reliable traffic from a station without a sliding window, and
        traffic on unknown ports, is logged and dropped.
        """
        port = message.protocol_port
        if port == self.PORT_UNRELIABLE:
            self.handle_message(station, message.payload)
        elif port == self.PORT_RELIABLE:
            if station.index == 0xFD:
                logger.warning(
                    "Received reliable mesh packet from unknown station")
                return
            transport = self.sliding_windows[station.index]
            if transport is None:
                logger.warning(
                    "Received reliable mesh packet from station without "
                    "sliding window")
                return
            transport.handle(message)
        else:
            # Previously referenced an undefined name ('packet') here.
            logger.warning("Unknown MeshProtocol port: %i", port)

    def handle_message(self, station, message):
        """Dispatch a mesh message payload by its first (type) byte.

        Empty payloads and types without a handler are logged and dropped
        instead of raising ``IndexError``/``KeyError``.
        """
        if not message:
            logger.warning("Received empty mesh message")
            return

        handler = self.handlers.get(message[0])
        if handler is None:
            logger.warning("Unhandled mesh message type: 0x%X", message[0])
            return
        handler(station, message)

    def handle_join_request(self, station, message):
        """Acknowledge a join request and emit ``on_join_request``."""
        logger.info("Received join request")
        station_address = StationAddress.deserialize(message[4:])
        station_index = message[1]
        self.station_protocol.send_ack(station, message)
        self.on_join_request(station, station_index, station_address)

    def handle_join_response(self, station, message):
        """Emit ``on_join_denied`` when byte 1 is 0, else decode the fragment."""
        logger.info("Received join response")
        if message[1] == 0:
            self.on_join_denied(station, message[4])
        else:
            self.station_protocol.send_ack(station, message)
            self.join_response_decoder.parse(station, message)

    def handle_leave_request(self, station, message):
        logger.warning("TODO: Handle leave request")

    def handle_leave_response(self, station, message):
        logger.warning("TODO: Handle leave response")

    def handle_destroy_mesh(self, station, message):
        """Decode a destroy request and emit ``on_destroy_request``."""
        logger.info("Received destroy request")
        station_address = StationAddress.deserialize(message[4:])
        station_index = message[1]
        self.on_destroy_request(station, station_index, station_address)

    def handle_destroy_response(self, station, message):
        """Emit ``on_destroy_response`` with the responder's station index."""
        logger.info("Received destroy response")
        station_index = message[1]
        self.on_destroy_response(station, station_index)

    def handle_update_mesh(self, station, message):
        logger.warning("TODO: Handle mesh update")

    def send_join_request(self, station):
        """Send a join request carrying our index and station address."""
        logger.info("Sending join request")

        data = bytes(
            [self.MESSAGE_JOIN_REQUEST, self.session.station.index, 0, 0])
        data += self.session.station.station_address().serialize()

        self.send(station, data, 0, True)

    def send_join_response(self, station, assigned_index, host_index,
                           stations):
        """Send the station list to a joining station, fragmented to fit.

        Each fragment holds as many entries as fit in
        ``transport.size_limit() - 0xC`` bytes.

        Raises:
            ValueError: if the size limit cannot hold a single entry.
        """
        logger.info("Sending join response")

        infosize = (StationConnectionInfo.sizeof() + 4) & ~3
        limit = self.transport.size_limit() - 0xC

        per_packet = limit // infosize
        if per_packet < 1:
            raise ValueError(
                "Transport size limit too small for a join response entry")
        fragments = (len(stations) + per_packet - 1) // per_packet

        for i in range(fragments):
            offset = i * per_packet
            remaining = len(stations) - offset
            num_infos = min(remaining, per_packet)

            data = bytes([
                self.MESSAGE_JOIN_RESPONSE,
                len(stations), host_index, assigned_index, fragments, i,
                num_infos, offset
            ])

            for j in range(num_infos):
                station_info = stations[offset + j]
                data += station_info.connection_info.serialize()
                data += bytes([station_info.index, 0])
            self.send(station, data, 0, True)

    def send_deny_join(self, station, reason):
        """Deny a join; sent once with flags 0 and once with flags 8."""
        logger.info("Denying join request")
        data = bytes([self.MESSAGE_JOIN_RESPONSE, 0, 0xFF, 0xFF, reason])
        self.send(station, data, 0)
        self.send(station, data, 8)

    def send_destroy_response(self, station, station_index):
        """Answer a destroy request; sent with flags 0 and again with flags 8."""
        logger.info("Sending destroy response")
        data = bytes([self.MESSAGE_DESTROY_RESPONSE, station_index])
        self.send(station, data, 0)
        self.send(station, data, 8)

    def send_update_mesh(self, counter, host_index, stations):
        """Send the full station list over every assigned sliding window."""
        logger.info("Sending mesh update")

        data = struct.pack(">BBBBIBBBB", self.MESSAGE_UPDATE_MESH,
                           len(stations), host_index, 0, counter, 1, 0,
                           host_index, 0)
        for station in stations:
            data += station.connection_info.serialize()
            data += bytes([station.index, 0])

        for reliable_transport in filter(None, self.sliding_windows):
            reliable_transport.send(data)

    def send(self, station, payload, flags, ack=False):
        """Send *payload* on the unreliable port.

        When *ack* is true, the message goes through the resending transport.
        """
        message = PIAMessage()
        message.flags = flags
        message.protocol_id = self.PROTOCOL_ID
        message.protocol_port = self.PORT_UNRELIABLE
        message.payload = payload
        if ack:
            self.resender.send(station, message)
        else:
            self.transport.send(station, message)
Example #12
0
class StationMgr:
    """Manage station connection state on top of StationProtocol events."""

    station_connected = signal.Signal()
    station_disconnected = signal.Signal()
    connection_denied = signal.Signal()

    def __init__(self, session):
        proto = session.station_protocol
        self.protocol = proto
        proto.on_connection_request.add(self.handle_connection_request)
        proto.on_connection_response.add(self.handle_connection_response)
        proto.on_connection_denied.add(self.handle_connection_denied)
        proto.on_disconnection_request.add(self.handle_disconnection_request)
        proto.on_disconnection_response.add(self.handle_disconnection_response)

        self.stations = StationTable()

        # Stations we have sent a connection request to and await an answer.
        self.pending_connect = []

    def handle_connection_request(self, station, connection_info, connection_id, is_inverse):
        """Accept a connection request whose connection info matches *station*.

        A mismatch is denied with reason 1. For a non-inverse request, an
        inverse request is sent back before the connection response.
        """
        known_station = self.stations.find_by_connection_info(connection_info)
        if station != known_station:
            logger.warning("Unexpected station connection info found in connection request")
            self.protocol.send_deny_connection(station, 1)
            return

        logger.info("Received connection request")
        station.connection_info = connection_info
        station.connection_id = connection_id

        if not is_inverse:
            self.protocol.send_connection_request(station, True)

        self.protocol.send_connection_response(station)

    def handle_connection_response(self, station, identification_info):
        """Mark a pending station connected and emit ``station_connected``."""
        if station not in self.pending_connect:
            logger.debug("Unexpected connection response received: %s" %identification_info.name)
            return

        logger.info("Station connected: %s" %identification_info.name)
        self.pending_connect.remove(station)
        station.identification_info = identification_info
        station.is_connected = True
        self.station_connected(station)

    def handle_connection_denied(self, station, reason):
        """Drop a pending station whose request was denied."""
        if station not in self.pending_connect:
            logger.warning("Unexpected denying connection response received")
            return

        logger.info("Received denying connection response (reason=%i)" %reason)
        self.pending_connect.remove(station)
        self.connection_denied(station)

    def handle_disconnection_request(self, station):
        """Always answer; emit ``station_disconnected`` only if we were connected."""
        logger.info("Received disconnection request")
        self.protocol.send_disconnection_response(station)
        if not station.is_connected:
            return
        station.is_connected = False
        self.station_disconnected(station)

    def handle_disconnection_response(self, station):
        """Finish a disconnect we initiated."""
        if not station.is_connected:
            logger.warning("Unexpected disconnection response received")
            return

        logger.info("Received disconnection response")
        station.is_connected = False
        self.station_disconnected(station)

    def connect(self, station):
        """Connect to *station*; emit immediately if already connected."""
        if not station.is_connected:
            self.pending_connect.append(station)
            self.protocol.send_connection_request(station)
        else:
            self.station_connected(station)

    def cancel_connection(self, station):
        """Stop waiting for a connection response from *station*."""
        if station in self.pending_connect:
            self.pending_connect.remove(station)

    def disconnect(self, station):
        """Disconnect *station*; emit immediately if not connected."""
        if station.is_connected:
            self.protocol.send_disconnection_request(station)
        else:
            self.station_disconnected(station)

    def create(self, address, rvcid):
        return self.stations.create(address, rvcid)

    def find_by_address(self, address):
        return self.stations.find_by_address(address)

    def find_by_connection_info(self, info):
        return self.stations.find_by_connection_info(info)

    def find_by_rvcid(self, rvcid):
        return self.stations.find_by_rvcid(rvcid)
Example #13
0
class StationProtocol:
    """PIA station protocol (protocol id 0x100).

    Handles connection/disconnection handshakes and acknowledgements on the
    unreliable port and re-emits them on the class-level ``on_*`` signals.
    """

    PROTOCOL_ID = 0x100

    PORT_UNRELIABLE = 0
    PORT_RELIABLE = 1

    MESSAGE_CONNECTION_REQUEST = 1
    MESSAGE_CONNECTION_RESPONSE = 2
    MESSAGE_DISCONNECTION_REQUEST = 3
    MESSAGE_DISCONNECTION_RESPONSE = 4
    MESSAGE_ACK = 5
    MESSAGE_RELAY_CONNECTION_REQUEST = 6
    MESSAGE_RELAY_CONNECTION_RESPONSE = 7

    on_connection_request = signal.Signal()
    on_connection_response = signal.Signal()
    on_connection_denied = signal.Signal()

    on_disconnection_request = signal.Signal()
    on_disconnection_response = signal.Signal()

    def __init__(self, session):
        self.session = session
        self.transport = session.transport
        self.resender = session.resending_transport

        # Relay connection messages (6/7) are declared but have no handler;
        # handle() logs and drops them.
        self.handlers = {
            self.MESSAGE_CONNECTION_REQUEST: self.handle_connection_request,
            self.MESSAGE_CONNECTION_RESPONSE: self.handle_connection_response,
            self.MESSAGE_DISCONNECTION_REQUEST: self.handle_disconnection_request,
            self.MESSAGE_DISCONNECTION_RESPONSE: self.handle_disconnection_response,
            self.MESSAGE_ACK: self.handle_ack
        }

        self.inverse_requests = {}
        self.connection_responses = {}

    def handle(self, station, message):
        """Dispatch a station protocol message by its first (type) byte.

        Only the unreliable port is supported. Empty payloads and message
        types without a handler are logged and dropped instead of raising
        ``IndexError``/``KeyError``.
        """
        if message.protocol_port != self.PORT_UNRELIABLE:
            logger.warning("Only unreliable station protocol is supported")
            return

        payload = message.payload
        if not payload:
            logger.warning("Received empty station protocol message")
            return

        handler = self.handlers.get(payload[0])
        if handler is None:
            logger.warning("Unhandled station protocol message type: %i",
                           payload[0])
            return
        handler(station, payload)

    def handle_connection_request(self, station, message):
        """Validate version 3, acknowledge, and emit ``on_connection_request``.

        Other versions are denied with reason 2.
        """
        if message[2] != 3:
            logger.warning("Unsupported version number found in connection request")
            self.send_deny_connection(station, 2)
            return

        connection_info = StationConnectionInfo.deserialize(message[4:])
        connection_id = message[1]
        is_inverse = message[3]

        self.send_ack(station, message)
        self.on_connection_request(station, connection_info, connection_id, is_inverse)

    def handle_connection_response(self, station, message):
        """Emit ``on_connection_denied`` for a non-zero reason byte.

        Otherwise acknowledge and emit ``on_connection_response``.
        """
        if message[1]:
            self.on_connection_denied(station, message[1])
        else:
            identification_info = IdentificationInfo.deserialize(message[4:])
            self.send_ack(station, message)
            self.on_connection_response(station, identification_info)

    def handle_disconnection_request(self, station, message):
        self.on_disconnection_request(station)

    def handle_disconnection_response(self, station, message):
        self.on_disconnection_response(station)

    def handle_ack(self, station, message):
        """Forward an acknowledgement to the resending transport."""
        self.resender.handle_ack(message)

    def send_connection_request(self, station, is_inverse=False):
        """Send a (possibly inverse) version-3 connection request, with resend."""
        if is_inverse:
            logger.info("Sending inverse connection request")
        else:
            logger.info("Sending connection request")
        data = bytes([self.MESSAGE_CONNECTION_REQUEST, self.session.station.connection_id, 3, is_inverse])
        data += self.session.station.connection_info.serialize()
        self.send(station, data, True)

    def send_connection_response(self, station):
        """Send an accepting connection response with our identification info."""
        logger.info("Sending connection response")
        data = bytes([self.MESSAGE_CONNECTION_RESPONSE, 0, 3, 3])
        data += self.session.station.identification_info.serialize()
        data += b"\0\0"  # Alignment
        self.send(station, data, True)

    def send_deny_connection(self, station, reason):
        """Send a connection response carrying a non-zero deny *reason*."""
        logger.info("Denying connection request")
        data = bytes([self.MESSAGE_CONNECTION_RESPONSE, reason, 3, 0])
        self.send(station, data)

    def send_disconnection_request(self, station):
        logger.info("Sending disconnection request")
        data = bytes([self.MESSAGE_DISCONNECTION_REQUEST])
        self.send(station, data)

    def send_disconnection_response(self, station):
        logger.info("Sending disconnection response")
        data = bytes([self.MESSAGE_DISCONNECTION_RESPONSE])
        self.send(station, data)

    def send_ack(self, station, message):
        """Acknowledge *message* using the big-endian uint32 in its last 4 bytes."""
        ack_id = struct.unpack_from(">I", message, -4)[0]
        logger.info("Acknowledging packet (%i)", ack_id)
        data = struct.pack(">BxxxI", self.MESSAGE_ACK, ack_id)
        self.send(station, data)

    def send(self, station, payload, ack=False):
        """Send *payload* on the unreliable port.

        When *ack* is true, the message goes through the resending transport.
        """
        message = PIAMessage()
        message.flags = 0
        message.protocol_id = self.PROTOCOL_ID
        message.protocol_port = self.PORT_UNRELIABLE
        message.payload = payload
        if ack:
            self.resender.send(station, message)
        else:
            self.transport.send(station, message)
Example #14
0
class MeshMgr:
    """Track mesh membership (host index, member stations) for a session.

    Reacts to MeshProtocol join/leave/destroy/update events and exposes the
    class-level signals ``mesh_created``, ``mesh_destroyed``,
    ``station_joined`` and ``station_left``.
    """

    mesh_created = signal.Signal()
    mesh_destroyed = signal.Signal()

    station_joined = signal.Signal()
    station_left = signal.Signal()

    # Values of self.join_state
    JOIN_OK = 0
    JOIN_DENIED = 1
    JOIN_WAITING = 2
    JOIN_NONE = 3

    def __init__(self, session):
        self.session = session
        self.protocol = session.mesh_protocol
        self.protocol.on_join_request.add(self.handle_join_request)
        self.protocol.on_join_response.add(self.handle_join_response)
        self.protocol.on_join_denied.add(self.handle_join_denied)
        self.protocol.on_leave_request.add(self.handle_leave_request)
        self.protocol.on_leave_response.add(self.handle_leave_response)
        self.protocol.on_destroy_request.add(self.handle_destroy_request)
        self.protocol.on_destroy_response.add(self.handle_destroy_response)
        self.protocol.on_mesh_update.add(self.handle_mesh_update)

        self.station_mgr = session.station_mgr
        self.connection_mgr = session.connection_mgr

        self.stations = StationList()
        self.host_index = None

        self.update_counter = -1

        self.join_state = self.JOIN_NONE

        # rvcid -> mesh index for stations awaiting connection; read by
        # handle_station_connected(), so it must exist before that runs.
        self.pending_connect = {}

    def is_host(self):
        """Return True if the local station's index equals the host index."""
        return self.session.station.index == self.host_index

    def handle_join_request(self, station, station_index, station_addr):
        """As host, admit *station* and broadcast a mesh update.

        Non-hosts deny with reason 1. An address mismatch is denied with
        reason 2.
        """
        if self.is_host():
            if station != self.station_mgr.find_by_address(station_addr.address):
                logger.warning("Received join request with unexpected station address")
                self.protocol.send_deny_join(station, 2)
            else:
                logger.info("Received join request")
                index = self.stations.next_index()
                self.protocol.send_join_response(station, index, self.host_index, self.stations)
                self.stations.add(station)
                self.protocol.assign_sliding_window(station)
                self.send_update_mesh()
                self.station_joined(station)
        else:
            logger.warning("Received join request even though we aren't host")
            self.protocol.send_deny_join(station, 1)

    def handle_join_response(self, station, host_index, my_index, infos):
        """Complete a pending join using the host's station list."""
        if self.join_state != self.JOIN_WAITING:
            logger.warning("Unexpected join response received")
            return

        # host_index / my_index arrive as positions into *infos*; translate
        # them into the station indices stored in those entries.
        host_index = infos[host_index].index
        my_index = infos[my_index].index
        logger.info("Received join response: (%i, %i)", host_index, my_index)
        self.join_state = self.JOIN_OK
        self.host_index = host_index
        self.stations.add(self.session.station, my_index)
        self.stations.add(station, host_index)
        self.protocol.assign_sliding_window(station)
        self.mesh_created(host_index, my_index)
        self.station_joined(station)

    def handle_join_denied(self, station, reason):
        logger.info("Join denied (%i)", reason)
        self.join_state = self.JOIN_DENIED

    def handle_leave_request(self, station, station_index, station_address):
        if self.is_host():
            logger.warning("TODO: Handle leave request")
        else:
            logger.warning("Unexpected leave request received")

    def handle_leave_response(self, station, station_index, station_address):
        logger.warning("Unexpected leave response received")

    def handle_destroy_request(self, station, station_index, station_address):
        """As a non-host, acknowledge destruction and emit ``mesh_destroyed``."""
        if self.is_host():
            logger.warning("Unexpected destroy request received")
        else:
            self.protocol.send_destroy_response(station, self.session.station.index)
            self.mesh_destroyed()

    def handle_destroy_response(self, station, station_index):
        logger.warning("Unexpected destroy response received")

    def handle_mesh_update(self, infos):
        """Add stations newly listed in a mesh update.

        New stations with an index below ours are connected via
        connection_mgr.
        """
        # NOTE(review): stations left in disconnect_list are never removed
        # or disconnected; confirm whether that is still TODO.
        disconnect_list = list(self.stations)
        connect_list = []
        for info in infos:
            station = self.station_mgr.find_by_connection_info(info.connection_info)
            if station in disconnect_list:
                disconnect_list.remove(station)
                if station.index != info.index:
                    logger.error("Station index changed unexpectedly (%i -> %i)",
                                 station.index, info.index)
            else:
                # NOTE(review): a new station is created even if station_mgr
                # already knew one for this connection info; verify intended.
                rvcid = info.connection_info.public_station.rvcid
                station = self.station_mgr.create(None, rvcid)
                self.stations.add(station, info.index)
                self.protocol.assign_sliding_window(station)
                if station.index < self.session.station.index:
                    connect_list.append(info.connection_info)

        self.connection_mgr.connect(*connect_list)

    def handle_station_connected(self, station):
        """Place a connected station listed in pending_connect at its index."""
        if station.rvcid in self.pending_connect:
            index = self.pending_connect.pop(station.rvcid)
            if self.stations.is_usable(index):
                self.stations.add(station, index)
                self.protocol.assign_sliding_window(station)
                self.station_joined(station)
            else:
                logger.warning("Tried to assign station to occupied index")

    def send_update_mesh(self):
        """Broadcast the current station list with an incremented counter."""
        self.update_counter += 1
        self.protocol.send_update_mesh(
            self.update_counter, self.host_index, self.stations
        )

    def create(self):
        """Create a new mesh with the local station as host."""
        self.stations.add(self.session.station)
        self.host_index = self.session.station.index

    def join(self, host_station):
        """Join the mesh hosted by *host_station*, blocking until done.

        Pumps ``scheduler.update()`` until the join is answered, then until
        every station in the mesh reports ``is_connected``. There is no
        timeout.

        Raises:
            RuntimeError: if the host denies the join request.
        """
        self.join_state = self.JOIN_WAITING
        self.protocol.send_join_request(host_station)
        while self.join_state == self.JOIN_WAITING:
            scheduler.update()
        if self.join_state == self.JOIN_DENIED:
            raise RuntimeError("Join request denied")

        logger.info("Wait until all stations are connected")
        while True:
            all_connected = all(station.is_connected for station in self.stations)
            # The scheduler is pumped once after every check, including the
            # final successful one.
            scheduler.update()
            if all_connected:
                break
        logger.info("Successfully joined a mesh!")