def send_advertisement(self):
    """Broadcast this node's current link-state advertisement.

    The advertisement carries our address, host name, current link-state
    epoch and the link states from ``generate_our_adverts()``. It is sent to
    ``self.BroadcastAddress`` as a LINK_STATE packet with TTL 7 at QoS.Lower.
    """
    # Generate the link states before reading the epoch, matching the
    # original evaluation order.
    our_link_states = self.generate_our_adverts()
    advertisement = LinkStateAdvertisementHeader(
        node=self.our_address,
        name=self.host_name,
        epoch=self.our_link_state_epoch,
        link_states=our_link_states,
    )
    self.debug("Sending Advertisement {}".format(advertisement))

    header = NetworkHeader(
        version=0,
        qos=QoS.Lower,
        protocol=Protocol.LINK_STATE,
        ttl=7,
        identity=self.next_sequence(),
        length=0,
        source=self.our_address,
        destination=self.BroadcastAddress,
    )
    out = BytesIO()
    for part in (header, advertisement):
        part.encode(out)
    self.send(header, out.getvalue())
def send_query(self, neighbor: MeshAddress):
    """Send a LINK_STATE_QUERY directly to ``neighbor`` (TTL 1, QoS.Lower).

    The query lists every node in ``self.valid_link_states()`` together with
    the epoch we hold for it in ``self.link_state_epochs``. The two lists are
    parallel. The header also carries our own address and current link-state
    epoch, so the neighbor can reply with the advertisements we lack.

    :param neighbor: address of the neighbor to query.
    """
    self.debug(f"Querying {neighbor} for link states")
    # Only the node ids are needed from valid_link_states(); the epochs come
    # from self.link_state_epochs (KeyError if a valid node has no epoch).
    link_nodes = list(self.valid_link_states())
    link_epochs = [self.link_state_epochs[node] for node in link_nodes]
    query = LinkStateQueryHeader(node=self.our_address,
                                 epoch=self.our_link_state_epoch,
                                 link_nodes=link_nodes,
                                 link_epochs=link_epochs)
    network_header = NetworkHeader(
        version=0,
        qos=QoS.Lower,
        protocol=Protocol.LINK_STATE_QUERY,
        ttl=1,
        identity=self.next_sequence(),
        length=0,
        source=self.our_address,
        destination=neighbor,
    )
    stream = BytesIO()
    network_header.encode(stream)
    query.encode(stream)
    self.send(network_header, stream.getvalue())
def test_encode_decode_fragment(self):
    """Round-trip NetworkHeader + FragmentHeader + DatagramHeader + payload.

    Every decoded layer is compared with the one that was encoded, including
    the fragment header.
    """
    msg = "Hello, World!".encode("utf-8")
    datagram1 = DatagramHeader(source=100, destination=100, length=len(msg),
                               checksum=crc_b(msg))
    fragment1 = FragmentHeader(Protocol.DATAGRAM, FragmentFlags.NONE, 0, 99)
    header1 = NetworkHeader(version=0, protocol=Protocol.FRAGMENT,
                            qos=QoS.Default, ttl=4, identity=42,
                            length=fragment1.size() + datagram1.size() + len(msg),
                            source=MeshAddress(1), destination=MeshAddress(2))
    stream = BytesIO()
    header1.encode(stream)
    fragment1.encode(stream)
    datagram1.encode(stream)
    stream.write(msg)
    stream.seek(0)

    header2 = NetworkHeader.decode(stream)
    fragment2 = FragmentHeader.decode(stream)
    datagram2 = DatagramHeader.decode(stream)
    msg2 = stream.read()

    self.assertEqual(header1, header2)
    self.assertEqual(fragment1, fragment2)
    self.assertEqual(datagram1, datagram2)
    self.assertEqual(msg2, msg)
def encode_packet(network_header: NetworkHeader, additional_headers: List[Header], payload: bytes) -> bytes:
    """Serialize a network header, optional extra headers and a payload.

    Each header is written by calling its ``encode(stream)`` method, in order:
    first ``network_header``, then every element of ``additional_headers``.
    The raw ``payload`` bytes follow.

    :returns: the concatenated bytes.
    """
    out = BytesIO()
    network_header.encode(out)
    for extra in additional_headers:
        extra.encode(out)
    out.write(payload)
    return out.getvalue()
def fragment_datagram(self, source: MeshAddress, destination: MeshAddress,
                      datagram_header: DatagramHeader, data: bytes):
    """Split a datagram into MTU-sized fragments and send each one.

    The datagram header and payload are serialized together, then cut into
    chunks so that ``FragmentHeader.size()`` plus a chunk fits within
    ``self.network.mtu()``. Flags are set as follows:

    * Every fragment except the last carries ``FragmentFlags.FRAGMENT``.
    * The last fragment carries ``FragmentFlags.NONE``.

    Fragments are numbered 0..n-1. They get consecutive sequence numbers
    starting at the current ``self.send_seq``, which is advanced by the
    number of fragments.
    """
    stream = BytesIO()
    datagram_header.encode(stream)
    stream.write(data)
    fragment_size = self.network.mtu() - FragmentHeader.size()
    fragments = list(chunks(stream.getvalue(), fragment_size))

    # Reserve one sequence number per fragment. send_seq is a plain integer
    # counter, so it cannot be used as a context manager. The reserved block
    # is [first_seq, first_seq + len(fragments)).
    first_seq = self.send_seq
    self.send_seq += len(fragments)

    last_index = len(fragments) - 1
    for i, fragment in enumerate(fragments):
        flags = FragmentFlags.FRAGMENT if i < last_index else FragmentFlags.NONE
        fragment_header = FragmentHeader(protocol=Protocol.DATAGRAM,
                                         flags=flags,
                                         fragment=i,
                                         sequence=first_seq + i)
        network_header = NetworkHeader(
            version=0,
            qos=QoS.Default,
            protocol=Protocol.FRAGMENT,
            ttl=MeshProtocol.DefaultTTL,
            identity=struct.unpack('<H', secrets.token_bytes(2))[0],
            # NOTE(review): excludes FragmentHeader.size(); the fragment
            # round-trip test counts it in `length` — confirm which is intended.
            length=len(fragment),
            source=source,
            destination=destination,
        )
        buffer = encode_packet(network_header, [fragment_header], fragment)
        self.network.send(network_header, buffer)
def handle_l4(self, network_header: NetworkHeader, stream: BytesIO):
    """Handle an incoming BROADCAST packet.

    Steps:

    1. Decode the BroadcastHeader.
    2. Drop the packet if ``self.ttl_cache`` already contains its
       (source, sequence) key.
    3. Forward it reliably to every alive neighbor except the originator and
       the neighbor it was received from.
    4. Deliver the payload to ``self.transport.handle_broadcast``.

    :param network_header: header of the packet as received.
    :param stream: positioned just after the network header.
    """
    broadcast_header = BroadcastHeader.decode(stream)
    self.debug(f"Handling {broadcast_header} {network_header}")
    # presumably ttl_cache.contains() also records the key, so a later copy of
    # the same broadcast is dropped here — confirm against TTLCache.
    if self.ttl_cache.contains(
            hash((broadcast_header.source, broadcast_header.sequence))):
        return
    data = stream.read()

    # The forwarded payload is identical for every neighbor; encode it once.
    buffer = encode_partial_packet([broadcast_header], data)
    # Forward to everyone we've heard from recently, except the originator and
    # the neighbor we heard it from. A separate name is used for the outgoing
    # header so `network_header.source` keeps referring to the received packet.
    for neighbor in self.network.alive_neighbors():
        if neighbor == broadcast_header.source:
            continue
        if neighbor == network_header.source:
            continue
        outgoing_header = NetworkHeader(
            version=0,
            qos=QoS.Default,
            protocol=Protocol.BROADCAST,
            ttl=1,
            identity=self.network.next_sequence(),
            length=0,
            source=self.network.our_address,
            destination=neighbor,
        )
        self.info(f"Forwarding {broadcast_header} to {neighbor}: {data}")
        self.reliable_manager.send(outgoing_header, buffer)

    # Now deliver it locally
    self.info(f"Delivering {broadcast_header}: {data}")
    self.transport.handle_broadcast(broadcast_header.source,
                                    broadcast_header.port, data)
def send_datagram(self, source: MeshAddress, destination: MeshAddress,
                  datagram_header: DatagramHeader, data: bytes,
                  reliable: bool = False):
    """Send a datagram from ``source`` to ``destination``.

    The behavior depends on ``reliable`` and the payload size:

    * Unreliable and too large for ``self.network.mtu()``: handed to
      ``self.fragment_protocol.fragment_datagram``.
    * Unreliable and small enough: sent as a single DATAGRAM packet.
    * Reliable: wrapped in a ReliableHeader and passed to
      ``self.reliable_protocol.send``.

    Every network header gets a random 16-bit identity.

    :raises RuntimeError: if a reliable datagram does not fit in one packet,
        since fragmentation is not supported for the reliable protocol.
    """
    def random_identity():
        # Two random bytes read as a little-endian unsigned 16-bit integer.
        return struct.unpack('<H', secrets.token_bytes(2))[0]

    if not reliable:
        if len(data) + DatagramHeader.size() > self.network.mtu():
            self.fragment_protocol.fragment_datagram(
                source, destination, datagram_header, data)
            return
        plain_header = NetworkHeader(
            version=0,
            qos=QoS.Default,
            protocol=Protocol.DATAGRAM,
            ttl=MeshProtocol.DefaultTTL,
            identity=random_identity(),
            length=0,
            source=source,
            destination=destination,
        )
        out = BytesIO()
        plain_header.encode(out)
        datagram_header.encode(out)
        out.write(data)
        self.network.send(plain_header, out.getvalue())
        return

    reliable_network_header = NetworkHeader(
        version=0,
        qos=QoS.Default,
        protocol=Protocol.RELIABLE,
        ttl=MeshProtocol.DefaultTTL,
        identity=random_identity(),
        length=0,
        source=source,
        destination=destination,
    )
    # NOTE(review): this outgoing request sets ReliableFlags.ACK; a receiver
    # that treats ACK-flagged packets as acknowledgements (rather than ACK
    # requests) would not deliver the payload — confirm the intended flag.
    reliable_header = ReliableHeader(protocol=Protocol.DATAGRAM,
                                     flags=ReliableFlags.ACK,
                                     sequence=reliable_network_header.identity,
                                     acknowledged=[])
    total = len(data) + DatagramHeader.size() + reliable_header.size()
    if total > self.network.mtu():
        raise RuntimeError(
            "Fragmentation not supported for reliable protocol")
    out = BytesIO()
    reliable_network_header.encode(out)
    reliable_header.encode(out)
    datagram_header.encode(out)
    out.write(data)
    self.reliable_protocol.send(reliable_network_header, reliable_header,
                                out.getvalue())
def handle_l4(self, network_header: NetworkHeader, stream: BytesIO):
    """Handle an incoming RELIABLE packet.

    If the ReliableHeader has ``ReliableFlags.ACK`` set, the packet is an
    acknowledgement and is passed to ``self.reliable_manager.handle_ack``.

    Otherwise the sender is requesting an ACK:

    * If the sender is not among our alive neighbors, the packet is ignored.
    * If it is, we send an ACK back to the sender with TTL 1.
    * Then we dispatch the remaining stream to the L4 handler for
      ``reliable_header.protocol``.
    """
    reliable_header = ReliableHeader.decode(stream)
    self.debug(f"Handling reliable {reliable_header}")
    if reliable_header.flags & ReliableFlags.ACK:
        # We received an ACK
        self.reliable_manager.handle_ack(network_header.source,
                                         reliable_header.sequence)
        return

    # An ACK is being requested
    if network_header.source not in self.network.alive_neighbors():
        # Ignore if we haven't seen this neighbor before, we need discovery first
        # TODO send discovery to this neighbor
        self.debug(
            f"Ignoring {reliable_header} since we don't know this neighbor."
        )
        return

    outgoing_network_header = NetworkHeader(
        version=0,
        qos=QoS.Default,
        protocol=Protocol.RELIABLE,
        ttl=1,
        identity=self.network.next_sequence(),
        length=0,
        source=self.network.our_address,
        destination=network_header.source,
    )
    ack_header = ReliableHeader(
        protocol=Protocol.NONE,
        flags=ReliableFlags.ACK,
        sequence=reliable_header.sequence,
    )
    response = BytesIO()
    outgoing_network_header.encode(response)
    ack_header.encode(response)
    # Log the peer we are acknowledging (the ACK's destination), not ourselves.
    self.info(
        f"Sending Ack for sequence {reliable_header.sequence} to address {outgoing_network_header.destination}"
    )
    self.network.send(outgoing_network_header, response.getvalue())
    self.handlers.handle_l4(network_header, reliable_header.protocol, stream)
def test_ttl_cache(self):
    """TTLCache.contains() misses on a new key, hits on a repeat, and forgets keys after the TTL."""
    clock = MockTime()
    cache = TTLCache(clock, 10)

    def make_header(identity):
        # Headers differ only by identity; everything else is fixed.
        return NetworkHeader(version=0, protocol=Protocol.DATAGRAM,
                             qos=QoS.Default, ttl=4, identity=identity,
                             length=0, source=MeshAddress(1),
                             destination=MeshAddress(2))

    header = make_header(42)
    # First lookup misses (and records the key); the second hits.
    self.assertFalse(cache.contains(hash(header)))
    self.assertTrue(cache.contains(hash(header)))

    # An equal, separately constructed header is seen as a duplicate.
    same_header = make_header(42)
    self.assertTrue(cache.contains(hash(same_header)))

    # A different identity is a new key.
    diff_header = make_header(43)
    self.assertFalse(cache.contains(hash(diff_header)))
    self.assertTrue(cache.contains(hash(diff_header)))

    # After the 10-unit TTL has passed, both keys have expired.
    clock.sleep(11)
    self.assertFalse(cache.contains(hash(header)))
    self.assertFalse(cache.contains(hash(diff_header)))
def send_hello(self):
    """Broadcast a HELLO packet (TTL 1, QoS.Lower).

    The HELLO carries the configured ``host.name`` and the result of
    ``self.alive_neighbors()``. It is sent to ``self.BroadcastAddress``.
    """
    self.debug("Sending Hello")
    hello = HelloHeader(self.config.get("host.name"), self.alive_neighbors())
    header = NetworkHeader(
        version=0,
        qos=QoS.Lower,
        protocol=Protocol.HELLO,
        ttl=1,
        identity=self.next_sequence(),
        length=0,
        source=self.our_address,
        destination=self.BroadcastAddress,
    )
    packet = BytesIO()
    for part in (header, hello):
        part.encode(packet)
    self.send(header, packet.getvalue())
def test_encode_decode_datagram(self):
    """encode_packet output decodes back to equal NetworkHeader and DatagramHeader."""
    payload = "Hello, World!".encode("utf-8")
    sent_datagram = DatagramHeader(source=100, destination=100,
                                   length=len(payload),
                                   checksum=crc_b(payload))
    sent_network = NetworkHeader(version=0, protocol=Protocol.DATAGRAM,
                                 qos=QoS.Default, ttl=4, identity=42,
                                 length=sent_datagram.size() + len(payload),
                                 source=MeshAddress(1),
                                 destination=MeshAddress(2))

    reader = BytesIO(encode_packet(sent_network, [sent_datagram], payload))
    received_network = NetworkHeader.decode(reader)
    received_datagram = DatagramHeader.decode(reader)

    self.assertEqual(sent_network, received_network)
    self.assertEqual(sent_datagram, received_datagram)
def send_ping(self, node: MeshAddress) -> int:
    """Send a PING control packet to ``node`` and record a pending PingResult.

    The control header carries one random byte (from
    ``secure_random_byte()``) as its extra data. A PingResult holding that
    byte, the send time in ns, no reply yet, and a Condition on
    ``self.mutex`` is appended to ``self.stats[node].results`` before sending.

    :returns: the first byte of the control header's ``extra`` (the ping id).
    """
    ctrl = ControlHeader(True, ControlType.PING, 1,
                         bytes([secure_random_byte()]))
    header = NetworkHeader(
        version=0,
        qos=QoS.Higher,
        protocol=Protocol.CONTROL,
        ttl=7,
        identity=self.network.next_sequence(),
        length=0,
        source=self.network.our_address,
        destination=node,
    )
    out = BytesIO()
    header.encode(out)
    ctrl.encode(out)
    packet = out.getvalue()

    ping_id = ctrl.extra[0]
    pending = PingResult(ping_id, time.time_ns(), None,
                         threading.Condition(self.mutex))
    self.stats[node].results.append(pending)
    self.network.send(header, packet)
    return ping_id
def test_send_receive(self):
    """A DATAGRAM sent from node 1 to node 4 reaches node 4's DATAGRAM handler intact."""
    msg = "Hello, Node 4".encode("utf-8")
    network_header = NetworkHeader(
        version=0,
        qos=QoS.Lower,
        protocol=Protocol.DATAGRAM,
        ttl=3,
        identity=10,
        length=0,
        source=MeshAddress(1),
        destination=MeshAddress(4),
    )
    datagram_header = DatagramHeader(source=100, destination=100,
                                     length=len(msg), checksum=0)
    packet = BytesIO()
    network_header.encode(packet)
    datagram_header.encode(packet)
    packet.write(msg)

    captured = None

    class MockTransportManager(L4Handler):
        # Records the payload that follows the DatagramHeader.
        def handle_l4(self, network_header: NetworkHeader, stream: BytesIO):
            nonlocal captured
            DatagramHeader.decode(stream)
            captured = stream.read()

    # Replace node 4's DATAGRAM handler with the capturing mock.
    self.node_4.l4_handlers.handlers[Protocol.DATAGRAM] = MockTransportManager()
    self.node_1.send(network_header, packet.getvalue())
    self.for_each_node(self.drain_queue)
    self.assertEqual(captured.decode("utf-8"), "Hello, Node 4")
def send_broadcast(self, port: int, data: bytes):
    """Reliably send ``data`` to every alive neighbor as a BROADCAST packet.

    One BroadcastHeader (our address, ``port``, a fresh sequence number, the
    payload length and its CRC) is shared by all copies. Each neighbor gets
    its own network header (TTL 1) and is sent the packet via
    ``self.reliable_manager``. Nothing is sent, and no sequence number is
    consumed, when there are no alive neighbors.

    :param port: destination port carried in the BroadcastHeader.
    :param data: payload bytes.
    """
    neighbors = self.network.alive_neighbors()
    if not neighbors:
        self.debug("Skipping broadcast since we have no neighbors.")
        return

    broadcast_header = BroadcastHeader(source=self.network.our_address,
                                       port=port,
                                       sequence=self.next_sequence(),
                                       length=len(data),
                                       checksum=crc_b(data))
    # The broadcast part of the packet is identical for every neighbor.
    buffer = encode_partial_packet([broadcast_header], data)
    for neighbor in neighbors:
        network_header = NetworkHeader(
            version=0,
            qos=QoS.Default,
            protocol=Protocol.BROADCAST,
            ttl=1,
            identity=self.network.next_sequence(),
            length=0,
            source=self.network.our_address,
            destination=neighbor,
        )
        self.info(f"Broadcasting {broadcast_header} to {neighbor}: {data}")
        self.reliable_manager.send(network_header, buffer)
def check_acks(self):
    """Resend or expire the unacknowledged packet at the head of ``self.sent``.

    Blocks on ``self.not_empty`` until ``self.sent`` is non-empty, then takes
    ``self.sent[0]`` and increments its ``attempt`` count:

    * If ``attempt`` exceeds ``ReliableManager.MaxResend``, the item is
      removed, ``self.network.failed_send`` is told its destination, and
      ``self.not_full`` is notified.
    * Otherwise the stored packet is re-sent with a fresh network-header
      identity.

    ``self.timer`` is reset at the end.
    """
    with self.not_empty:
        # Wait (releasing the lock while waiting) until something is queued.
        while len(self.sent) == 0:
            self.not_empty.wait()
        # NOTE(review): only the head item is handled per call, so later items
        # are not resent until the head is acked or expires — confirm intended.
        item = self.sent[0]
        item.attempt += 1
        if item.attempt > ReliableManager.MaxResend:
            self.info(f"Expiring {item}")
            self.sent.remove(item)
            self.network.failed_send(item.header.destination)
            # A slot has been freed; wake a producer waiting on not_full.
            self.not_full.notify()
        else:
            self.info(f"Resending {item}")
            stream = BytesIO(item.buffer)
            network_header = NetworkHeader.decode(stream)
            # New identity for the resend — presumably so receivers that
            # de-duplicate by header hash do not drop it; confirm.
            network_header = dataclasses.replace(
                network_header, identity=self.network.next_sequence())
            # Overwrite the header in place; assumes the re-encoded header is
            # the same size as the decoded one.
            stream.seek(0)
            network_header.encode(stream)
            stream.seek(0)
            self.network.send(network_header, stream.read())
        # TODO backoff
        self.timer.reset()
def process_incoming(self, payload: L2Payload):
    """Decode an incoming L3 packet and dispatch, deliver or forward it.

    Order of handling:

    1. Decode the NetworkHeader; on failure log an error and drop the packet.
    2. L3 protocols: CONTROL addressed to us (only PING is supported), HELLO,
       LINK_STATE and LINK_STATE_QUERY are handled and not processed further.
    3. Drop packets whose header hash is already in ``self.header_cache``.
    4. Packets addressed to us go to the L4 handlers.
    5. Broadcasts go to the L4 handlers. While TTL > 1 they are also
       re-broadcast with TTL - 1 on every link except the arrival link.
    6. Unicast packets for other nodes are forwarded with TTL - 1 (excluding
       the arrival link), but only while TTL > 1.
    """
    stream = BytesIO(payload.l3_data)
    try:
        network_header = NetworkHeader.decode(stream)
    except Exception as e:
        # Malformed input from the link is logged and dropped, not raised.
        self.error(f"Could not decode network packet from {payload}.", e)
        return

    # Handle L3 protocols first
    if network_header.destination == self.our_address and network_header.protocol == Protocol.CONTROL:
        ctrl = ControlHeader.decode(stream)
        self.info(f"Got {ctrl} from {network_header.source}")
        if ctrl.control_type == ControlType.PING:
            self.ping_protocol.handle_ping(network_header, ctrl)
        else:
            self.warning(
                f"Ignoring unsupported control packet: {ctrl.control_type}"
            )
        return

    if network_header.protocol == Protocol.HELLO:
        self.handle_hello(payload.link_id, network_header,
                          HelloHeader.decode(stream))
        return

    if network_header.protocol == Protocol.LINK_STATE:
        self.handle_advertisement(
            payload.link_id, network_header,
            LinkStateAdvertisementHeader.decode(stream))
        return

    if network_header.protocol == Protocol.LINK_STATE_QUERY:
        self.handle_query(payload.link_id, network_header,
                          LinkStateQueryHeader.decode(stream))
        return

    # Now decide if we should handle or drop
    if self.header_cache.contains(hash(network_header)):
        self.debug(f"Dropping duplicate {network_header}")
        return

    # If the packet is addressed to us, handle it
    if network_header.destination == self.our_address:
        self.debug(f"Handling {network_header}")
        self.l4_handlers.handle_l4(network_header, network_header.protocol,
                                   stream)
        return

    if network_header.destination == self.BroadcastAddress:
        self.debug(f"Handling broadcast {network_header}")
        self.l4_handlers.handle_l4(network_header, network_header.protocol,
                                   stream)
        if network_header.ttl > 1:
            # Decrease the TTL and re-broadcast on all links except where we heard it
            header_copy = dataclasses.replace(network_header,
                                              ttl=network_header.ttl - 1)
            stream.seek(0)
            header_copy.encode(stream)
            self.send(header_copy, stream.getvalue(),
                      exclude_link_id=payload.link_id)
        else:
            self.debug("Not re-broadcasting due to TTL")
        return

    # Unicast for another node. Like the broadcast path, stop once the TTL is
    # exhausted so a packet for an unreachable node cannot circulate forever.
    if network_header.ttl > 1:
        header_copy = dataclasses.replace(network_header,
                                          ttl=network_header.ttl - 1)
        # Rewrite the header at the front of the original bytes, then send
        # the whole buffer (header + untouched remainder).
        stream.seek(0)
        header_copy.encode(stream)
        self.send(header_copy, stream.getvalue(),
                  exclude_link_id=payload.link_id)
    else:
        self.debug("Not forwarding due to TTL")
def handle_query(self, link_id: int, network_header: NetworkHeader,
                 query: LinkStateQueryHeader):
    """Reply to a link-state query with the advertisements the querier lacks.

    For every node in ``self.valid_link_states()`` other than the querier, an
    advertisement is queued when either:

    * the querier does not list that node, or
    * ``lollipop_compare(their_epoch, our_epoch) > -1``.

    Our own advertisement is queued when the querier does not list our
    address, or under the same epoch test. Each advertisement is sent back to
    the querier as its own LINK_STATE packet with TTL 1.

    :param link_id: not used by this method.
    :param network_header: header of the received query; its source is the
        reply destination.
    :param query: the decoded query. ``link_nodes`` and ``link_epochs`` are
        parallel lists.
    """
    self.debug(f"Handling {query}")
    dest = network_header.source

    def their_epoch_for(node):
        # Epoch the querier reported for `node` (first occurrence).
        return query.link_epochs[query.link_nodes.index(node)]

    def advert_for(node):
        return LinkStateAdvertisementHeader(
            node=node,
            name=self.host_names.get(node, "unknown"),
            epoch=self.link_state_epochs.get(node),
            link_states=list(self.link_states.get(node)))

    adverts = []
    # Check our local cache of link states
    for node in self.valid_link_states():
        if node == dest:
            continue
        if node not in query.link_nodes:
            # Sender is missing this node's data
            adverts.append(advert_for(node))
        # NOTE(review): sends when lollipop_compare(theirs, ours) > -1; if
        # lollipop_compare returns -1 when its first argument is older, this
        # condition is inverted — confirm its sign convention.
        elif lollipop_compare(their_epoch_for(node),
                              self.link_state_epochs[node]) > -1:
            adverts.append(advert_for(node))

    # Include our latest state if they don't have it (including when they
    # have never heard of us at all).
    if self.our_address not in query.link_nodes:
        send_ours = True
    else:
        send_ours = lollipop_compare(their_epoch_for(self.our_address),
                                     self.our_link_state_epoch) > -1
    if send_ours:
        link_states = self.generate_our_adverts()
        adverts.append(LinkStateAdvertisementHeader(
            node=self.our_address,
            name=self.host_name,
            epoch=self.our_link_state_epoch,
            link_states=link_states))

    for advert in adverts:
        self.debug(f"Sending {advert} advert to {network_header.source}")
        resp_header = NetworkHeader(
            version=0,
            qos=QoS.Lower,
            protocol=Protocol.LINK_STATE,
            ttl=1,
            identity=self.next_sequence(),
            length=0,
            source=self.our_address,
            destination=dest,
        )
        stream = BytesIO()
        resp_header.encode(stream)
        advert.encode(stream)
        self.send(resp_header, stream.getvalue())