def do_tty(self, args):
    node = MeshAddress.parse(args.destination)
    port = args.port
    if node == MeshAddress.parse("ff.ff"):
        self.tty = self.transport_manager.broadcast(
            partial(ShellDatagramProtocol, self.stdout, self.prompt),
            self.network.our_address, port)
    else:
        self.tty = self.transport_manager.connect(
            partial(ShellDatagramProtocol, self.stdout, self.prompt),
            self.network.our_address, node, port)
    self.tty_port = port
    self.prompt = f"(tarpn {node}) "
def decode(cls, data: BytesIO):
    node = MeshAddress(ByteUtils.read_uint16(data))
    epoch = ByteUtils.read_int8(data)
    count = ByteUtils.read_uint8(data)
    link_nodes = []
    link_epochs = []
    for _ in range(count):
        link_nodes.append(MeshAddress(ByteUtils.read_uint16(data)))
        link_epochs.append(ByteUtils.read_int8(data))
    return cls(node=node, epoch=epoch, link_nodes=link_nodes, link_epochs=link_epochs)
def decode(cls, data: BytesIO):
    name = ByteUtils.read_utf8(data)
    count = ByteUtils.read_uint8(data)
    neighbors = []
    for _ in range(count):
        neighbors.append(MeshAddress(ByteUtils.read_uint16(data)))
    return cls(name=name, neighbors=neighbors)
def decode(cls, data: BytesIO):
    node = MeshAddress(ByteUtils.read_uint16(data))
    name = ByteUtils.read_utf8(data)
    epoch = ByteUtils.read_int8(data)
    count = ByteUtils.read_uint8(data)
    link_states = []
    for _ in range(count):
        link_states.append(LinkStateHeader.decode(data))
    return cls(node=node, name=name, epoch=epoch, link_states=link_states)
def decode(cls, data: BytesIO):
    byte = ByteUtils.read_int8(data)
    version = ByteUtils.hi_bits(byte, 4)
    protocol = ByteUtils.lo_bits(byte, 4)
    byte = ByteUtils.read_int8(data)
    qos = ByteUtils.hi_bits(byte, 3)
    ttl = ByteUtils.lo_bits(byte, 5)
    identity = ByteUtils.read_uint16(data)
    length = ByteUtils.read_uint16(data)
    source = ByteUtils.read_uint16(data)
    dest = ByteUtils.read_uint16(data)
    return cls(version=version,
               protocol=Protocol(protocol),
               qos=qos,
               ttl=ttl,
               identity=identity,
               length=length,
               source=MeshAddress(source),
               destination=MeshAddress(dest))
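# The decode above implies a 10-byte network header: one byte packing version (high 4 bits)
# and protocol (low 4 bits), one byte packing QoS (high 3 bits) and TTL (low 5 bits), then
# four 16-bit fields (identity, length, source, destination). The sketch below is an
# illustration of that layout only, not the project's NetworkHeader.encode / ByteUtils code;
# big-endian byte order and all field values are assumptions made for the example.
import struct


def pack_network_header_sketch(version: int, protocol: int, qos: int, ttl: int,
                               identity: int, length: int, source: int, dest: int) -> bytes:
    """Pack the 10-byte layout implied by NetworkHeader.decode (illustration only)."""
    byte0 = ((version & 0x0F) << 4) | (protocol & 0x0F)  # version | protocol
    byte1 = ((qos & 0x07) << 5) | (ttl & 0x1F)           # qos | ttl
    # '!' = network byte order (an assumption here), B = uint8, H = uint16
    return struct.pack("!BBHHHH", byte0, byte1, identity, length, source, dest)


if __name__ == "__main__":
    buf = pack_network_header_sketch(0, 3, 1, 7, identity=42, length=0,
                                     source=0x001A, dest=0xFFFF)
    assert len(buf) == 10  # matches MeshProtocol.HeaderBytes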
def decode(cls, data: BytesIO):
    source = ByteUtils.read_uint16(data)
    port = ByteUtils.read_uint8(data)
    seq = ByteUtils.read_uint16(data)
    length = ByteUtils.read_uint16(data)
    checksum = ByteUtils.read_uint16(data)
    return cls(source=MeshAddress(source),
               port=port,
               sequence=seq,
               length=length,
               checksum=checksum)
def do_ping(self, args):
    node = MeshAddress.parse(args.destination)
    self.stdout.write(f"Sending ping ({len(self.network.alive_neighbors())})...\r\n")
    for _ in range(args.count):
        t0 = time.time_ns()
        seq = self.network.ping_protocol.send_ping(node)
        found = self.network.ping_protocol.wait_for_ping(node, seq,
                                                         timeout_ms=args.timeout * 1000)
        t1 = time.time_ns()
        if found:
            dt = (t1 - t0) / 1000000.
            self.stdout.write(f"Got response in {dt}ms\r\n")
        else:
            self.stdout.write("Timed out waiting for response\r\n")
def write(self, data: Any) -> None:
    _, mtu = self.network.route_packet(MeshAddress.parse("ff.ff"))
    if isinstance(data, str):
        encoded_data = data.encode("utf-8")
    elif isinstance(data, (bytes, bytearray)):
        encoded_data = data
    else:
        raise ValueError("DatagramTransport.write only supports bytes and strings")
    if len(encoded_data) <= mtu:
        self.broadcast_protocol.send_broadcast(self.port, encoded_data)
    else:
        raise RuntimeError(f"Message too large, maximum size is {mtu}")
def test_send_receive(self):
    network_header = NetworkHeader(
        version=0,
        qos=QoS.Lower,
        protocol=Protocol.DATAGRAM,
        ttl=3,
        identity=10,
        length=0,
        source=MeshAddress(1),
        destination=MeshAddress(4),
    )
    msg = "Hello, Node 4".encode("utf-8")
    datagram_header = DatagramHeader(source=100, destination=100, length=len(msg), checksum=0)

    stream = BytesIO()
    network_header.encode(stream)
    datagram_header.encode(stream)
    stream.write(msg)
    stream.seek(0)

    captured = None

    class MockTransportManager(L4Handler):
        def handle_l4(self, network_header: NetworkHeader, stream: BytesIO):
            nonlocal captured
            DatagramHeader.decode(stream)
            captured = stream.read()

    self.node_4.l4_handlers.handlers[Protocol.DATAGRAM] = MockTransportManager()
    self.node_1.send(network_header, stream.read())
    self.for_each_node(self.drain_queue)
    self.assertEqual(captured.decode("utf-8"), "Hello, Node 4")
def broadcast(self, protocol_factory: Callable[[], DProtocol], local_address: MeshAddress,
              port: int) -> DProtocol:
    if port in self.connections:
        if self.connections[port][0].is_closing():
            del self.connections[port]
        else:
            raise RuntimeError(f"Connection to {port} is already open")
    protocol = protocol_factory()
    transport = BroadcastTransport(self.l3_protocol, self.broadcast_protocol,
                                   self.datagram_protocol, port, local_address,
                                   MeshAddress.parse("ff.ff"))
    protocol.connection_made(transport)
    self.connections[port] = (transport, protocol)
    return protocol
def write_to(self, address: str, data: Any) -> None:
    dest = MeshAddress.parse(address)
    can_route, mtu = self.network.route_packet(dest)
    if can_route:
        if isinstance(data, str):
            encoded_data = data.encode("utf-8")
        elif isinstance(data, (bytes, bytearray)):
            encoded_data = data
        else:
            raise ValueError("DatagramTransport.write only supports bytes and strings")
        if len(encoded_data) <= mtu:
            header = DatagramHeader(source=self.port,
                                    destination=self.port,
                                    length=len(encoded_data),
                                    checksum=crc_b(encoded_data))
            self.datagram_protocol.send_datagram(self.local, self.remote, header, encoded_data)
            # self.broadcast_protocol.send_broadcast(self.port, encoded_data)
        else:
            raise RuntimeError(f"Message too large, maximum size is {mtu}")
    else:
        raise RuntimeError(f"Cannot route to {dest}!")
def main():
    parser = argparse.ArgumentParser(description='Broadcast to mesh network over serial device')
    parser.add_argument("device", help="Serial port to open")
    parser.add_argument("baud", type=int, help="Baudrate to use")
    parser.add_argument("callsign", help="Your callsign (e.g., K4DBZ-10)")
    parser.add_argument("address", help="Local address, e.g., 00.1a")
    parser.add_argument("port", type=int, help="Port", default=10)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    # Configure logging
    main_logger = logging.getLogger("root")
    main_logger.setLevel(logging.ERROR)
    main_logger.addHandler(logging.StreamHandler(sys.stdout))
    if args.debug:
        main_logger.setLevel(logging.DEBUG)
        state_logger = logging.getLogger("ax25.state")
        state_logger.setLevel(logging.DEBUG)
        state_logger.addHandler(logging.StreamHandler(sys.stdout))

    scheduler = Scheduler()

    # Configure and initialize I/O device and L2
    port_config = PortConfig.from_dict(0, {
        "port.enabled": True,
        "port.type": "serial",
        "serial.device": args.device,
        "serial.speed": args.baud
    })

    # Initialize I/O device and L2
    l3_protocols = L3Protocols()
    l2_multi = DefaultLinkMultiplexer(L3PriorityQueue, scheduler)
    l2_queueing = L2FIFOQueue(20, AX25Protocol.maximum_frame_size())
    l2 = AX25Protocol(port_config, port_config.port_id(), AX25Call.parse(args.callsign),
                      scheduler, l2_queueing, l2_multi, l3_protocols)
    kiss = KISSProtocol(port_config.port_id(), l2_queueing,
                        port_config.get_boolean("kiss.checksum", False))
    SerialDevice(kiss, port_config.get("serial.device"), port_config.get_int("serial.speed"),
                 port_config.get_float("serial.timeout"), scheduler)
    scheduler.submit(L2IOLoop(l2_queueing, l2))

    addr = MeshAddress.parse(args.address)
    mesh_l3 = MeshProtocol(our_address=addr, link_multiplexer=l2_multi, scheduler=scheduler)
    l3_protocols.register(mesh_l3)
    mesh_l4 = MeshTransportManager(mesh_l3)

    tty = TTY()
    loop = asyncio.get_event_loop()
    loop.add_reader(sys.stdin, tty.handle_stdin)
    loop.add_signal_handler(signal.SIGTERM, tty.handle_signal, loop, scheduler)
    loop.add_signal_handler(signal.SIGINT, tty.handle_signal, loop, scheduler)
    mesh_l4.connect(protocol_factory=lambda: tty,
                    port=args.port,
                    local_address=addr,
                    remote_address=MeshProtocol.BroadcastAddress)

    try:
        loop.run_forever()
    finally:
        loop.close()
def decode(cls, data: BytesIO):
    node = MeshAddress(ByteUtils.read_uint16(data))
    via = MeshAddress(ByteUtils.read_uint16(data))
    quality = ByteUtils.read_uint8(data)
    return cls(node=node, via=via, quality=quality)
            records.append(record)
            self.epochs[address] = record.epoch
            return True
        else:
            return False

    def get_link_states(self, node: MeshAddress) -> Dict[MeshAddress, int]:
        states = {}
        if node not in self.epochs.keys():
            return states
        for record in self.records[node]:
            if isinstance(record, LinkRecord):
                link_record = cast(LinkRecord, record)
                states[link_record.source] = link_record.quality
        return states


if __name__ == "__main__":
    log = Log()
    nodeA = MeshAddress.parse("00.aa")
    nodeB = MeshAddress.parse("00.bb")
    log.append(nodeA, LinkRecord(-128, nodeA, nodeB, 100))
    log.append(nodeA, LinkRecord(-127, nodeA, nodeB, 99))
    log.append(nodeA, LinkRecord(10, nodeA, nodeB, 98))
    log.append(nodeA, LinkRecord(9, nodeA, nodeB, 97))
    log.append(nodeA, LinkRecord(-128, nodeA, nodeB, 100))
    print(log.get_link_states(nodeA))
    print(log.get_link_states(nodeB))
    print(log.records)
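# The epochs in the demo above follow the "lollipop sequence" that MeshProtocol uses for its
# link state epochs: a node starts on a negative ramp (-128, -127, ...) and then cycles through
# the non-negative values, so a receiver can tell a restart apart from ordinary wrap-around.
# The functions below are a minimal sketch of that idea only; they are NOT the project's
# lollipop_sequence / lollipop_compare implementations. The ranges are assumed from the int8
# epochs above, and the return convention mirrors how handle_advertisement branches on the
# result (1 = newer, 0 = reset, -1 = stale).
from typing import Iterator


def lollipop_sequence_sketch() -> Iterator[int]:
    # Assumed behavior: climb the "stick" from -128 to -1 once, then cycle 0..127 forever,
    # so a freshly restarted node always announces a negative epoch.
    epoch = -128
    while epoch < 0:
        yield epoch
        epoch += 1
    while True:
        yield epoch
        epoch = (epoch + 1) % 128


def lollipop_compare_sketch(last: int, received: int) -> int:
    # 0  -> sender appears to have restarted (back on the negative stick)
    # 1  -> received epoch is newer than the last one we saw
    # -1 -> received epoch is stale
    if received < 0 and last >= received:
        return 0
    if received > last:
        return 1
    return -1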
def run_node(args):
    # Bootstrap node.ini
    if not os.path.exists(args.config) and os.path.basename(args.config) == "node.ini":
        shutil.copyfile("config/node.ini.sample", args.config)

    # Load settings from ini file
    s = Settings(".", ["config/defaults.ini", args.config])
    node_settings = s.node_config()
    node_call = AX25Call.parse(node_settings.node_call())
    if node_call.callsign == "N0CALL":
        print("Callsign is missing from config. Please see instructions here "
              "https://github.com/tarpn/tarpn-node-controller")
        sys.exit(1)
    else:
        print(f"Loaded configuration for {node_call}")

    # Setup logging
    logging_config_file = node_settings.get("log.config", "not_set")
    if logging_config_file != "not_set":
        log_dir = node_settings.get("log.dir")
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        logging.config.fileConfig(logging_config_file,
                                  defaults={"log.dir": log_dir},
                                  disable_existing_loggers=False)
    if args.verbose:
        logging.getLogger("root").setLevel(logging.DEBUG)

    # Create thread pool
    scheduler = Scheduler()

    # Initialize I/O devices and L2 protocols
    l3_protocols = L3Protocols()
    l3_protocols.register(NoLayer3Protocol())
    l2_multi = DefaultLinkMultiplexer(L3PriorityQueue, scheduler)

    # Port UDP mapping
    # udp.forwarding.enabled = true
    # udp.address = 192.168.0.160:10093
    # udp.destinations = K4DBZ-2,NODES
    # udp.mapping = KN4ORB-2:1,KA2DEW-2:2
    port_queues = {}
    if node_settings.get_boolean("udp.enabled", False):
        udp_host, udp_port = node_settings.get("udp.address").split(":")
        udp_port = int(udp_port)
        udp_writer = UDPWriter(g8bpq_address=(udp_host, udp_port))
        intercept_dests = {
            AX25Call.parse(c)
            for c in node_settings.get("udp.destinations", "").split(",")
        }
        interceptor = udp_writer.receive_frame
        udp_mapping = {}
        for mapping in node_settings.get("udp.mapping", "").split(","):
            c, i = mapping.split(":")
            udp_mapping[AX25Call.parse(c)] = int(i)
        scheduler.submit(UDPThread("0.0.0.0", udp_port, udp_mapping, port_queues, udp_writer))
    else:
        intercept_dests = {}
        interceptor = lambda frame: None

    for port_config in s.port_configs():
        if port_config.get_boolean("port.enabled") and port_config.get("port.type") == "serial":
            l2_queueing = L2FIFOQueue(20, AX25Protocol.maximum_frame_size())
            port_queues[port_config.port_id()] = l2_queueing
            l2 = AX25Protocol(port_config, port_config.port_id(), node_call, scheduler,
                              l2_queueing, l2_multi, l3_protocols, intercept_dests, interceptor)
            kiss = KISSProtocol(port_config.port_id(), l2_queueing,
                                port_config.get_boolean("kiss.checksum", False))
            SerialDevice(kiss, port_config.get("serial.device"),
                         port_config.get_int("serial.speed"),
                         port_config.get_float("serial.timeout"), scheduler)
            scheduler.submit(L2IOLoop(l2_queueing, l2))

    # Register L3 protocols
    routing_table = tarpn.netrom.router.NetRomRoutingTable.load(
        f"nodes-{node_settings.node_call()}.json", node_settings.node_alias())
    network_configs = s.network_configs()
    if network_configs.get_boolean("netrom.enabled", False):
        logger.info("Starting NET/ROM router")
        netrom_l3 = NetRomL3(node_call, node_settings.node_alias(), scheduler, l2_multi,
                             routing_table)
        l3_protocols.register(netrom_l3)
        netrom_l4 = NetRomTransportProtocol(s.network_configs(), netrom_l3, scheduler)

    l4_handlers = L4Handlers()
    if network_configs.get_boolean("mesh.enabled", False):
        mesh_l3 = MeshProtocol(WallTime(), network_configs, l2_multi, l4_handlers, scheduler)
        l3_protocols.register(mesh_l3)

    # Create the L4 protocols
    mesh_l4 = MeshTransportManager(mesh_l3)

    # Register L4 handlers
    reliable = ReliableManager(mesh_l3, scheduler)
    fragment_protocol = FragmentProtocol(mesh_l3, mesh_l4)
    reliable_protocol = ReliableProtocol(mesh_l3, reliable, l4_handlers)
    datagram_protocol = DatagramProtocol(mesh_l3, mesh_l4, fragment_protocol, reliable_protocol)
    broadcast_protocol = BroadcastProtocol(mesh_l3, mesh_l4, reliable)
    l4_handlers.register_l4(Protocol.FRAGMENT, fragment_protocol)
    l4_handlers.register_l4(Protocol.RELIABLE, reliable_protocol)
    l4_handlers.register_l4(Protocol.DATAGRAM, datagram_protocol)
    l4_handlers.register_l4(Protocol.BROADCAST, broadcast_protocol)

    # TODO fix circular dependency here
    mesh_l4.broadcast_protocol = broadcast_protocol
    mesh_l4.datagram_protocol = datagram_protocol

    # Bind the command processor
    ncp_factory = partial(NodeCommandProcessor,
                          config=network_configs,
                          link=l2_multi,
                          network=mesh_l3,
                          transport_manager=mesh_l4,
                          scheduler=scheduler)
    mesh_l4.bind(ncp_factory, mesh_l3.our_address, 11)

    # Set up applications
    for app_config in s.app_configs():
        # We have a single unix socket connection multiplexed to many network connections
        print(app_config)
        app_multiplexer = TransportMultiplexer()
        app_address = MeshTransportAddress.parse(app_config.get("app.address"))
        app_protocol = ApplicationProtocol(app_config.app_name(), app_config.app_alias(),
                                           str(app_address.address), mesh_l4, app_multiplexer)
        scheduler.submit(UnixServerThread(app_config.app_socket(), app_protocol))
        multiplexer_protocol = partial(MultiplexingProtocol, app_multiplexer)
        # TODO bind or connect?
        mesh_l4.connect(multiplexer_protocol, app_address.address, MeshAddress.parse("00.a2"),
                        app_address.port)

    sock = node_settings.get("node.sock")
    print(f"Binding node terminal to {sock}")
    scheduler.submit(UnixServerThread(sock, TarpnShellProtocol(mesh_l3, mesh_l4)))

    # Start a metrics reporter
    # reporter = ConsoleReporter(reporting_interval=300)
    # scheduler.timer(10_000, reporter.start, True)
    # scheduler.add_shutdown_hook(reporter.stop)

    logger.info("Finished Startup")
    try:
        # Wait for all threads
        scheduler.join()
    except KeyboardInterrupt:
        scheduler.shutdown()
def parse(cls, s: str):
    uri_parts = urllib.parse.urlsplit(s)
    assert uri_parts.scheme == "mesh"
    address = MeshAddress.parse(uri_parts.hostname)
    return cls(address=address, port=uri_parts.port)
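# Usage sketch for the classmethod above. It is assumed to be MeshTransportAddress.parse,
# since run_node parses app.address with that name; the URI value below is hypothetical.
# urlsplit lower-cases the host part and returns the port as an int.
if __name__ == "__main__":
    addr = MeshTransportAddress.parse("mesh://00.2a:100")
    print(addr.address)  # MeshAddress parsed from "00.2a"
    print(addr.port)     # 100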
class MeshProtocol(CloseableThreadLoop, L3Protocol, LoggingMixin):
    """
    A simple protocol for a partially connected mesh network.

    Nodes send HELLO packets to their neighbors frequently. This is used as a way to
    initialize a link to a neighbor and as a failure detector. Once a neighbor is discovered,
    an ADVERTISE packet is sent which informs the neighbor of this node's current state.
    If a node receives an ADVERTISE with a newer epoch, it will forward that packet so other
    nodes can learn about this new state.
    """

    ProtocolId = 0xB0
    WindowSize = 1024
    MaxFragments = 8
    HeaderBytes = 10
    DefaultTTL = 7
    BroadcastAddress = MeshAddress(0xFFFF)

    def __init__(self, time: Time, config: NetworkConfig, link_multiplexer: LinkMultiplexer,
                 l4_handlers: L4Handlers, scheduler: Scheduler):
        LoggingMixin.__init__(self, extra_func=self.log_ident)
        CloseableThreadLoop.__init__(self, name="MeshNetwork")
        self.time = time
        self.config = config
        self.link_multiplexer = link_multiplexer
        self.l4_handlers = l4_handlers
        self.scheduler = scheduler
        self.queue = queue.Queue()
        self.our_address = MeshAddress.parse(config.get("mesh.address"))
        self.host_name = config.get("host.name")
        self.ping_protocol = PingProtocol(self)

        # TTL cache of seen frames from each source
        self.header_cache: TTLCache = TTLCache(time, 30_000)

        # Our own send sequence
        self.send_seq: int = 1
        self.seq_lock = threading.Lock()

        # Link states and neighbors
        self.neighbors: Dict[MeshAddress, Neighbor] = dict()

        # An epoch for our own link state changes. Any time a neighbor comes or goes, or the
        # quality changes, we increment this counter. Uses a "lollipop sequence" to allow for
        # easy detection of wrap around vs reset
        self.our_link_state_epoch_generator = lollipop_sequence()
        self.our_link_state_epoch: int = next(self.our_link_state_epoch_generator)

        # Epochs we have received from other nodes and their link states
        self.link_state_epochs: Dict[MeshAddress, int] = dict()
        self.host_names: Dict[MeshAddress, str] = dict()
        self.link_states: Dict[MeshAddress, List[LinkStateHeader]] = dict()

        self.last_hello_time = datetime.fromtimestamp(0)
        self.last_advert = datetime.utcnow()
        self.last_query = datetime.utcnow()
        self.last_epoch_bump = datetime.utcnow()

        self.scheduler.timer(1_000, partial(self.scheduler.submit, self), True)

    def __repr__(self):
        return f"<MeshProtocol {self.our_address}>"

    def log_ident(self) -> str:
        return f"[MeshProtocol {self.our_address}]"

    def next_sequence(self) -> int:
        with self.seq_lock:
            seq = self.send_seq
            self.send_seq += 1
        return seq % MeshProtocol.WindowSize

    def next_sequences(self, n) -> List[int]:
        # Allocate n consecutive sequence numbers under the lock
        seqs = []
        with self.seq_lock:
            for i in range(n):
                seqs.append((self.send_seq + i) % MeshProtocol.WindowSize)
            self.send_seq += n
        return seqs

    def neighbors(self, since=300) -> Set[MeshAddress]:
        return set(self.neighbors.keys())

    def up_neighbors(self) -> List[MeshAddress]:
        return [n.address for n in self.neighbors.values() if n.state == NeighborState.UP]

    def alive_neighbors(self) -> List[MeshAddress]:
        return [
            n.address for n in self.neighbors.values()
            if n.state in (NeighborState.UP, NeighborState.INIT)
        ]

    def down_neighbors(self) -> List[MeshAddress]:
        return [n.address for n in self.neighbors.values() if n.state == NeighborState.DOWN]

    def valid_link_states(self) -> Dict[MeshAddress, List[LinkStateHeader]]:
        now = datetime.utcnow()
        result = dict()
        for node, link_states in self.link_states.items():
            result[node] = []
            for link_state in link_states:
                if (now - link_state.created).seconds < 300:
                    result[node].append(link_state)
        return result

    def can_handle(self, protocol: int) -> bool:
        return protocol == MeshProtocol.ProtocolId

    def pre_close(self):
        # Erase our neighbors and send ADVERT
        self.neighbors.clear()
        self.our_link_state_epoch = next(self.our_link_state_epoch_generator)
        self.send_advertisement()
        time.sleep(1)  # TODO better solution is to wait for L3 queue to drain in close

    def close(self):
        self.wakeup()
        CloseableThreadLoop.close(self)

    def iter_loop(self) -> bool:
        # Check if we need to take some periodic action like sending a HELLO
        now = datetime.utcnow()
        deadline = self.deadline(now)

        # Now wait at most the deadline for new incoming packets
        try:
            event = self.queue.get(block=True, timeout=deadline)
            if event is not None:
                self.process_incoming(event)
            return True
        except queue.Empty:
            return False

    def deadline(self, now: datetime) -> int:
        # TODO use time.time_ns instead of datetime
        return min([
            self.check_dead_neighbors(now),
            self.check_hello(now),
            self.check_epoch(now),
            self.check_advert(now),
            self.check_query(now)
        ])

    def wakeup(self):
        """Wake up the main thread"""
        self.queue.put(None)

    def check_dead_neighbors(self, now: datetime) -> int:
        min_deadline = self.config.get_int("mesh.dead.interval")
        for neighbor in list(self.neighbors.values()):
            if neighbor.state == NeighborState.DOWN:
                continue
            deadline = self.config.get_int("mesh.dead.interval") - \
                (now - neighbor.last_seen).seconds
            if deadline <= 0:
                self.info(f"Neighbor {neighbor.address} detected as DOWN!")
                neighbor.state = NeighborState.DOWN
                self.our_link_state_epoch = next(self.our_link_state_epoch_generator)
                self.last_epoch_bump = datetime.utcnow()
                self.last_advert = datetime.fromtimestamp(0)  # Force our advert to go out
            else:
                min_deadline = min(deadline, min_deadline)
        return min_deadline

    def check_hello(self, now: datetime) -> int:
        deadline = self.config.get_int("mesh.hello.interval") - \
            (now - self.last_hello_time).seconds
        if deadline <= 0:
            self.send_hello()
            self.last_hello_time = now
            return self.config.get_int("mesh.hello.interval")
        else:
            return deadline

    def check_epoch(self, now: datetime) -> int:
        max_age = self.config.get_int("mesh.advert.max.age")
        to_delete = []
        for node, links in self.link_states.items():
            for link in list(links):
                if (now - link.created).seconds > max_age:
                    self.debug(f"Expiring link state {link} for {node}")
                    links.remove(link)
            if len(links) == 0:
                to_delete.append(node)
        for node in to_delete:
            del self.link_states[node]

        deadline = int(max_age * .80) - (now - self.last_epoch_bump).seconds
        if deadline <= 0:
            self.our_link_state_epoch = next(self.our_link_state_epoch_generator)
            self.last_epoch_bump = now
            self.send_advertisement()
            return int(max_age * .80)
        else:
            return deadline

    def check_advert(self, now: datetime) -> int:
        deadline = self.config.get_int("mesh.advert.interval") - (now - self.last_advert).seconds
        if deadline > 0:
            return deadline
        else:
            self.send_advertisement()
            self.last_advert = now
            return self.config.get_int("mesh.advert.interval")

    def check_query(self, now: datetime) -> int:
        deadline = self.config.get_int("mesh.query.interval") - (now - self.last_query).seconds
        if deadline > 0:
            return deadline
        else:
            for neighbor in self.up_neighbors():
                self.send_query(neighbor)
            self.last_query = now
            return self.config.get_int("mesh.query.interval")

    def handle_l2_payload(self, payload: L2Payload):
        """
        Handle an inbound packet from L2. We add this to the queue, which wakes up the main
        thread to process this packet.
        """
        self.queue.put(payload)

    def process_incoming(self, payload: L2Payload):
        stream = BytesIO(payload.l3_data)
        try:
            network_header = NetworkHeader.decode(stream)
        except Exception as e:
            self.error(f"Could not decode network packet from {payload}.", e)
            return

        # Handle L3 protocols first
        if network_header.destination == self.our_address and \
                network_header.protocol == Protocol.CONTROL:
            ctrl = ControlHeader.decode(stream)
            self.info(f"Got {ctrl} from {network_header.source}")
            if ctrl.control_type == ControlType.PING:
                self.ping_protocol.handle_ping(network_header, ctrl)
            else:
                self.warning(f"Ignoring unsupported control packet: {ctrl.control_type}")
            return

        if network_header.protocol == Protocol.HELLO:
            self.handle_hello(payload.link_id, network_header, HelloHeader.decode(stream))
            return

        if network_header.protocol == Protocol.LINK_STATE:
            self.handle_advertisement(payload.link_id, network_header,
                                      LinkStateAdvertisementHeader.decode(stream))
            return

        if network_header.protocol == Protocol.LINK_STATE_QUERY:
            self.handle_query(payload.link_id, network_header,
                              LinkStateQueryHeader.decode(stream))
            return

        # Now decide if we should handle or drop
        if self.header_cache.contains(hash(network_header)):
            self.debug(f"Dropping duplicate {network_header}")
            return

        # If the packet is addressed to us, handle it
        if network_header.destination == self.our_address:
            self.debug(f"Handling {network_header}")
            self.l4_handlers.handle_l4(network_header, network_header.protocol, stream)
            return

        if network_header.destination == self.BroadcastAddress:
            self.debug(f"Handling broadcast {network_header}")
            self.l4_handlers.handle_l4(network_header, network_header.protocol, stream)
            if network_header.ttl > 1:
                # Decrease the TTL and re-broadcast on all links except where we heard it
                header_copy = dataclasses.replace(network_header, ttl=network_header.ttl - 1)
                stream.seek(0)
                header_copy.encode(stream)
                stream.seek(0)
                self.send(header_copy, stream.read(), exclude_link_id=payload.link_id)
            else:
                self.debug("Not re-broadcasting due to TTL")
        else:
            header_copy = dataclasses.replace(network_header, ttl=network_header.ttl - 1)
            stream.seek(0)
            header_copy.encode(stream)
            stream.seek(0)
            self.send(header_copy, stream.read(), exclude_link_id=payload.link_id)

    def handle_hello(self, link_id: int, network_header: NetworkHeader, hello: HelloHeader):
        self.debug(f"Handling hello {hello}")
        now = datetime.utcnow()
        sender = network_header.source
        if network_header.source not in self.neighbors:
            self.info(f"Saw new neighbor {sender} ({hello.name})")
            self.neighbors[sender] = Neighbor(address=sender,
                                              name=hello.name,
                                              link_id=link_id,
                                              neighbors=hello.neighbors,
                                              last_seen=now,
                                              last_update=now,
                                              state=NeighborState.INIT)
            self.our_link_state_epoch = next(self.our_link_state_epoch_generator)
            self.last_epoch_bump = now
        else:
            self.neighbors[sender].neighbors = hello.neighbors
            self.neighbors[sender].last_seen = now

        if self.our_address in hello.neighbors:
            delay = 100
            if self.neighbors[sender].state != NeighborState.UP:
                self.info(f"Neighbor {sender} is UP!")
                self.scheduler.timer(delay, partial(self.send_query, sender), auto_start=True)
                self.neighbors[sender].state = NeighborState.UP
                self.neighbors[sender].last_update = now
                delay *= 1.2
        else:
            self.info(f"Neighbor {sender} is initializing...")
            self.neighbors[sender].state = NeighborState.INIT
            self.neighbors[sender].last_update = now

    def send_hello(self):
        self.debug("Sending Hello")
        hello = HelloHeader(self.config.get("host.name"), self.alive_neighbors())
        network_header = NetworkHeader(
            version=0,
            qos=QoS.Lower,
            protocol=Protocol.HELLO,
            ttl=1,
            identity=self.next_sequence(),
            length=0,
            source=self.our_address,
            destination=self.BroadcastAddress,
        )
        stream = BytesIO()
        network_header.encode(stream)
        hello.encode(stream)
        stream.seek(0)
        buffer = stream.read()
        self.send(network_header, buffer)

    def handle_advertisement(self, link_id: int, network_header: NetworkHeader,
                             advert: LinkStateAdvertisementHeader):
        if advert.node == self.our_address:
            return

        latest_epoch = self.link_state_epochs.get(advert.node)
        if latest_epoch is None:
            self.debug(f"Initializing link state for {advert.node} with epoch {advert.epoch}")
            update = True
        else:
            epoch_cmp = lollipop_compare(latest_epoch, advert.epoch)
            update = epoch_cmp > -1
            if epoch_cmp == 1:
                self.debug(f"Updating link state for {advert.node}. "
                           f"Received epoch is {advert.epoch}, last known was {latest_epoch}")
            elif epoch_cmp == 0:
                self.debug(f"Resetting link state for {advert.node}. "
                           f"Received epoch is {advert.epoch}, last known was {latest_epoch}")
            else:
                self.debug(f"Ignoring stale link state for {advert.node}. "
                           f"Received epoch is {advert.epoch}, last known was {latest_epoch}")

        if update:
            self.link_states[advert.node] = advert.link_states
            self.link_state_epochs[advert.node] = advert.epoch
            self.host_names[advert.node] = advert.name

            # Forward this packet to all neighbors except where we heard it
            if network_header.ttl > 1:
                network_header_copy = dataclasses.replace(network_header,
                                                          ttl=network_header.ttl - 1)
                stream = BytesIO()
                network_header_copy.encode(stream)
                advert.encode(stream)
                stream.seek(0)
                buffer = stream.read()
                self.send(network_header_copy, buffer, exclude_link_id=link_id)

    def generate_our_adverts(self) -> List[LinkStateHeader]:
        link_states = []
        for address in self.up_neighbors():
            cost = self.link_multiplexer.get_link_cost(self.neighbors.get(address).link_id)
            link_states.append(LinkStateHeader(node=address, via=self.our_address, quality=cost))
        return link_states

    def send_advertisement(self):
        link_states = self.generate_our_adverts()
        advertisement = LinkStateAdvertisementHeader(node=self.our_address,
                                                     name=self.host_name,
                                                     epoch=self.our_link_state_epoch,
                                                     link_states=link_states)
        self.debug("Sending Advertisement {}".format(advertisement))
        network_header = NetworkHeader(
            version=0,
            qos=QoS.Lower,
            protocol=Protocol.LINK_STATE,
            ttl=7,
            identity=self.next_sequence(),
            length=0,
            source=self.our_address,
            destination=self.BroadcastAddress,
        )
        stream = BytesIO()
        network_header.encode(stream)
        advertisement.encode(stream)
        stream.seek(0)
        buffer = stream.read()
        self.send(network_header, buffer)

    def handle_query(self, link_id: int, network_header: NetworkHeader,
                     query: LinkStateQueryHeader):
        self.debug(f"Handling {query}")
        dest = network_header.source
        adverts = []

        # Check our local cache of link states
        for node, link_states in self.valid_link_states().items():
            if node == dest:
                continue
            if node not in query.link_nodes:
                # Sender is missing this node's data
                header = LinkStateAdvertisementHeader(
                    node=node,
                    name=self.host_names.get(node, "unknown"),
                    epoch=self.link_state_epochs.get(node),
                    link_states=list(self.link_states.get(node)))
                adverts.append(header)
            else:
                idx = query.link_nodes.index(node)
                epoch = query.link_epochs[idx]
                if lollipop_compare(epoch, self.link_state_epochs[node]) > -1:
                    header = LinkStateAdvertisementHeader(
                        node=node,
                        name=self.host_names.get(node, "unknown"),
                        epoch=self.link_state_epochs.get(node),
                        link_states=list(self.link_states.get(node)))
                    adverts.append(header)

        # Include our latest state if they don't have it
        if self.our_address in query.link_nodes:
            idx = query.link_nodes.index(self.our_address)
            epoch = query.link_epochs[idx]
            if lollipop_compare(epoch, self.our_link_state_epoch) > -1:
                link_states = self.generate_our_adverts()
                our_advert = LinkStateAdvertisementHeader(node=self.our_address,
                                                          name=self.host_name,
                                                          epoch=self.our_link_state_epoch,
                                                          link_states=link_states)
                adverts.append(our_advert)

        dest = network_header.source
        for advert in adverts:
            self.debug(f"Sending {advert} advert to {network_header.source}")
            resp_header = NetworkHeader(
                version=0,
                qos=QoS.Lower,
                protocol=Protocol.LINK_STATE,
                ttl=1,
                identity=self.next_sequence(),
                length=0,
                source=self.our_address,
                destination=dest,
            )
            stream = BytesIO()
            resp_header.encode(stream)
            advert.encode(stream)
            stream.seek(0)
            buffer = stream.read()
            self.send(resp_header, buffer)

    def send_query(self, neighbor: MeshAddress):
        self.debug(f"Querying {neighbor} for link states")
        known_link_states = dict()
        for node, link_states in self.valid_link_states().items():
            known_link_states[node] = self.link_state_epochs[node]
        query = LinkStateQueryHeader(node=self.our_address,
                                     epoch=self.our_link_state_epoch,
                                     link_nodes=list(known_link_states.keys()),
                                     link_epochs=list(known_link_states.values()))
        network_header = NetworkHeader(
            version=0,
            qos=QoS.Lower,
            protocol=Protocol.LINK_STATE_QUERY,
            ttl=1,
            identity=self.next_sequence(),
            length=0,
            source=self.our_address,
            destination=neighbor,
        )
        stream = BytesIO()
        network_header.encode(stream)
        query.encode(stream)
        stream.seek(0)
        buffer = stream.read()
        self.send(network_header, buffer)

    def route_to(self, address: MeshAddress) -> List[MeshAddress]:
        g = nx.DiGraph()

        # Other's links
        for node, link_states in self.valid_link_states().items():
            for link_state in link_states:
                g.add_weighted_edges_from([(node, link_state.node, link_state.quality)])

        # Our links
        for neighbor in self.up_neighbors():
            cost = self.link_multiplexer.get_link_cost(self.neighbors.get(neighbor).link_id)
            g.add_weighted_edges_from([(self.our_address, neighbor, cost)])

        try:
            # Compute the shortest path
            path = nx.dijkstra_path(g, self.our_address, address)
            # Ensure we have a return path
            nx.dijkstra_path(g, address, self.our_address)
            return path
        except NetworkXException:
            return []

    def send(self, header: NetworkHeader, buffer: bytes, exclude_link_id: Optional[int] = None):
        """
        Send a packet to a network destination. If the destination address is ff.ff, the packet
        is broadcast on all available L2 links (optionally excluding a given link).

        :param header the header of the packet to broadcast
        :param buffer the entire buffer of the packet to broadcast
        :param exclude_link_id an L2 link to exclude from the broadcast
        """
        if header.destination == MeshProtocol.BroadcastAddress:
            links = self.link_multiplexer.links_for_address(AX25Address("TARPN"),
                                                            exclude_link_id)
        elif header.destination in self.up_neighbors():
            neighbor = self.neighbors.get(header.destination)
            links = [neighbor.link_id]
        else:
            best_route = self.route_to(header.destination)
            self.debug(f"Routing {header} via {best_route}")
            if len(best_route) > 1:
                next_hop = best_route[1]
                hop_neighbor = self.neighbors.get(next_hop)
                if hop_neighbor is not None:
                    links = [hop_neighbor.link_id]
                else:
                    self.error(f"Calculated route including {next_hop}, "
                               f"but we're missing that neighbor.")
                    links = []
            else:
                self.warning(f"No route to {header.destination}, dropping.")
                links = []

        for link_id in links:
            payload = L3Payload(source=header.source,
                                destination=header.destination,
                                protocol=MeshProtocol.ProtocolId,
                                buffer=buffer,
                                link_id=link_id,
                                qos=QoS(header.qos),
                                reliable=False)
            self.debug(f"Sending {payload}")
            self.link_multiplexer.offer(payload)

    def failed_send(self, neighbor: MeshAddress):
        self.debug(f"Marking neighbor {neighbor} as failed.")

    def mtu(self):
        # We want uniform packets, so get the min L2 MTU
        return self.link_multiplexer.mtu() - MeshProtocol.HeaderBytes

    def route_packet(self, address: L3Address) -> Tuple[bool, int]:
        # Subtract L3 header size and multiply by max fragments
        l3_mtu = MeshProtocol.MaxFragments * (self.mtu() - MeshProtocol.HeaderBytes)
        if address == MeshProtocol.BroadcastAddress:
            return True, l3_mtu
        else:
            path = self.route_to(cast(MeshAddress, address))
            return len(path) > 0, l3_mtu

    def send_packet(self, payload: L3Payload) -> bool:
        return self.link_multiplexer.offer(payload)

    def listen(self, address: MeshAddress):
        # By default we listen for all addresses
        pass

    def register_transport_protocol(self, protocol) -> None:
        # TODO remove this from L3Protocol
        pass
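# Worked example for route_packet's size limit (the L2 MTU value is hypothetical): with an L2
# MTU of 255 bytes, mtu() returns 255 - HeaderBytes = 245, and the largest routable payload is
# MaxFragments * (245 - HeaderBytes) = 8 * 235 = 1880 bytes. The L3 header size is subtracted
# once inside mtu() and once more in route_packet, mirroring the code above.
if __name__ == "__main__":
    l2_mtu = 255  # hypothetical link MTU
    l3_mtu = MeshProtocol.MaxFragments * \
        ((l2_mtu - MeshProtocol.HeaderBytes) - MeshProtocol.HeaderBytes)
    assert l3_mtu == 1880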