Ejemplo n.º 1
0
    def send_advertisement(self):
        """Broadcast this node's current link-state advertisement.

        The advertisement carries our address, host name, current link-state
        epoch and the link states produced by ``generate_our_adverts()``. It
        is wrapped in a LINK_STATE network header addressed to
        ``BroadcastAddress`` (lower QoS, TTL 7, next sequence number as the
        identity) and the encoded bytes are handed to ``self.send``.
        """
        our_links = self.generate_our_adverts()
        advertisement = LinkStateAdvertisementHeader(
            node=self.our_address,
            name=self.host_name,
            epoch=self.our_link_state_epoch,
            link_states=our_links,
        )
        self.debug(f"Sending Advertisement {advertisement}")

        header = NetworkHeader(
            version=0,
            qos=QoS.Lower,
            protocol=Protocol.LINK_STATE,
            ttl=7,
            identity=self.next_sequence(),
            length=0,
            source=self.our_address,
            destination=self.BroadcastAddress,
        )

        # Network header first, then the advertisement body.
        packet = BytesIO()
        for part in (header, advertisement):
            part.encode(packet)
        self.send(header, packet.getvalue())
Ejemplo n.º 2
0
    def send_query(self, neighbor: MeshAddress):
        """Ask a direct neighbor for link states we are missing or hold stale.

        The query lists every node we currently hold a valid link state for,
        paired with the epoch recorded for it in ``link_state_epochs``, plus
        our own address and epoch. It is sent straight to ``neighbor``
        (TTL 1, lower QoS).

        Args:
            neighbor: Address of the adjacent node to query.
        """
        self.debug(f"Querying {neighbor} for link states")
        # Only the node keys of valid_link_states() matter here; the epochs
        # come from the separate link_state_epochs table. Keys and epochs are
        # kept in the same order so they pair up index-by-index.
        known_nodes = list(self.valid_link_states())
        known_epochs = [self.link_state_epochs[node] for node in known_nodes]

        query = LinkStateQueryHeader(node=self.our_address,
                                     epoch=self.our_link_state_epoch,
                                     link_nodes=known_nodes,
                                     link_epochs=known_epochs)

        network_header = NetworkHeader(
            version=0,
            qos=QoS.Lower,
            protocol=Protocol.LINK_STATE_QUERY,
            ttl=1,
            identity=self.next_sequence(),
            length=0,
            source=self.our_address,
            destination=neighbor,
        )

        stream = BytesIO()
        network_header.encode(stream)
        query.encode(stream)
        self.send(network_header, stream.getvalue())
Ejemplo n.º 3
0
    def test_encode_decode_fragment(self):
        """Round-trip a network + fragment + datagram header stack and payload.

        The three headers and the payload are written to one stream, decoded
        back in the same order, and every layer (including the fragment
        header) plus the trailing payload must match what was encoded.
        """
        msg = "Hello, World!".encode("utf-8")
        datagram1 = DatagramHeader(source=100,
                                   destination=100,
                                   length=len(msg),
                                   checksum=crc_b(msg))

        fragment1 = FragmentHeader(Protocol.DATAGRAM, FragmentFlags.NONE, 0,
                                   99)

        header1 = NetworkHeader(version=0,
                                protocol=Protocol.FRAGMENT,
                                qos=QoS.Default,
                                ttl=4,
                                identity=42,
                                length=fragment1.size() + datagram1.size() +
                                len(msg),
                                source=MeshAddress(1),
                                destination=MeshAddress(2))

        stream = BytesIO()
        header1.encode(stream)
        fragment1.encode(stream)
        datagram1.encode(stream)
        stream.write(msg)
        stream.seek(0)

        header2 = NetworkHeader.decode(stream)
        fragment2 = FragmentHeader.decode(stream)
        datagram2 = DatagramHeader.decode(stream)
        msg2 = stream.read()
        self.assertEqual(header1, header2)
        self.assertEqual(fragment1, fragment2)
        self.assertEqual(datagram1, datagram2)
        self.assertEqual(msg2, msg)
Ejemplo n.º 4
0
def encode_packet(network_header: NetworkHeader,
                  additional_headers: List[Header], payload: bytes) -> bytes:
    """Serialize a full packet: network header, extra headers, then payload.

    Args:
        network_header: Header written first.
        additional_headers: Further headers, each encoded in list order right
            after the network header.
        payload: Raw bytes appended after all headers.

    Returns:
        The concatenation of every encoded header followed by ``payload``.
    """
    out = BytesIO()
    network_header.encode(out)
    for extra in additional_headers:
        extra.encode(out)
    out.write(payload)
    return out.getvalue()
Ejemplo n.º 5
0
    def fragment_datagram(self, source: MeshAddress, destination: MeshAddress,
                          datagram_header: DatagramHeader, data: bytes):
        """Split a datagram into MTU-sized fragments and send each one.

        The datagram header and payload are serialized together and cut into
        chunks of ``self.network.mtu() - FragmentHeader.size()`` bytes. Each
        chunk is sent as its own FRAGMENT packet whose FragmentHeader carries
        the chunk index and a sequence number. All chunks except the last are
        flagged ``FragmentFlags.FRAGMENT``, and the last is flagged
        ``FragmentFlags.NONE``.

        Args:
            source: Address placed in each fragment's network header.
            destination: Final destination of every fragment.
            datagram_header: Header of the datagram being fragmented.
            data: Datagram payload.
        """
        serialized = BytesIO()
        datagram_header.encode(serialized)
        serialized.write(data)
        fragment_size = self.network.mtu() - FragmentHeader.size()
        fragments = list(chunks(serialized.getvalue(), fragment_size))

        # NOTE(review): ``send_seq`` is used below as an integer counter, yet it
        # is also entered as a context manager here; a plain int cannot be
        # used in ``with``. This presumably wants a lock guarding the counter
        # -- TODO confirm.
        with self.send_seq:
            first_seq = self.send_seq
            # Reserve one consecutive sequence number per fragment. The range
            # must end at first_seq + count (not at count) or it comes up short
            # as soon as send_seq > 0.
            sequences = range(first_seq, first_seq + len(fragments))
            self.send_seq += len(fragments)

        last_index = len(fragments) - 1
        for index, fragment in enumerate(fragments):
            # Every fragment but the last announces that more follow.
            if index < last_index:
                flags = FragmentFlags.FRAGMENT
            else:
                flags = FragmentFlags.NONE
            fragment_header = FragmentHeader(protocol=Protocol.DATAGRAM,
                                             flags=flags,
                                             fragment=index,
                                             sequence=sequences[index])
            network_header = NetworkHeader(
                version=0,
                qos=QoS.Default,
                protocol=Protocol.FRAGMENT,
                ttl=MeshProtocol.DefaultTTL,
                identity=struct.unpack('<H', secrets.token_bytes(2))[0],
                length=len(fragment),
                source=source,
                destination=destination,
            )
            buffer = encode_packet(network_header, [fragment_header], fragment)
            self.network.send(network_header, buffer)
Ejemplo n.º 6
0
    def handle_l4(self, network_header: NetworkHeader, stream: BytesIO):
        """Handle an incoming BROADCAST packet: flood it onward, then deliver it.

        A broadcast already seen (per ``ttl_cache``, keyed on the hash of the
        broadcast's (source, sequence) pair) is dropped. Otherwise the payload
        goes through ``reliable_manager`` to every alive neighbor except the
        broadcast's originator and the neighbor that relayed it to us. It is
        then handed to ``transport.handle_broadcast`` for local delivery.

        Args:
            network_header: Header of the packet as it arrived.
            stream: Stream positioned just after the network header.
        """
        broadcast_header = BroadcastHeader.decode(stream)
        self.debug(f"Handling {broadcast_header} {network_header}")

        # contains() also records the key on a miss (see test_ttl_cache), so
        # the first copy passes and later copies are dropped.
        if self.ttl_cache.contains(
                hash((broadcast_header.source, broadcast_header.sequence))):
            return

        data = stream.read()
        # The forwarded payload is identical for every neighbor.
        buffer = encode_partial_packet([broadcast_header], data)
        # Remember who relayed this to us before building outgoing headers.
        # The loop must not rebind ``network_header``, or later iterations
        # would compare neighbors against our own address instead.
        relayed_by = network_header.source

        # Now forward it to everyone we've heard from recently except who we heard it from
        for neighbor in self.network.alive_neighbors():
            if neighbor == broadcast_header.source:
                continue
            if neighbor == relayed_by:
                continue
            outgoing_header = NetworkHeader(
                version=0,
                qos=QoS.Default,
                protocol=Protocol.BROADCAST,
                ttl=1,
                identity=self.network.next_sequence(),
                length=0,
                source=self.network.our_address,
                destination=neighbor,
            )
            self.info(f"Forwarding {broadcast_header} to {neighbor}: {data}")
            self.reliable_manager.send(outgoing_header, buffer)

        # Now deliver it locally
        self.info(f"Delivering {broadcast_header}: {data}")
        self.transport.handle_broadcast(broadcast_header.source,
                                        broadcast_header.port, data)
Ejemplo n.º 7
0
 def send_datagram(self,
                   source: MeshAddress,
                   destination: MeshAddress,
                   datagram_header: DatagramHeader,
                   data: bytes,
                   reliable: bool = False):
     """Send a datagram from ``source`` to ``destination``.

     With ``reliable=True`` the datagram is wrapped in a ReliableHeader and
     handed to ``reliable_protocol``. Reliable datagrams must fit in one
     packet, otherwise RuntimeError is raised. Without it, datagrams that
     exceed the MTU are passed to ``fragment_protocol``, and smaller ones are
     sent directly as a single DATAGRAM packet.

     Raises:
         RuntimeError: if a reliable datagram would need fragmentation.
     """
     def header_for(protocol):
         # Fresh header with a random 16-bit identity (little-endian decode).
         return NetworkHeader(
             version=0,
             qos=QoS.Default,
             protocol=protocol,
             ttl=MeshProtocol.DefaultTTL,
             identity=struct.unpack('<H', secrets.token_bytes(2))[0],
             length=0,
             source=source,
             destination=destination,
         )

     if not reliable:
         if len(data) + DatagramHeader.size() > self.network.mtu():
             self.fragment_protocol.fragment_datagram(
                 source, destination, datagram_header, data)
             return
         plain_header = header_for(Protocol.DATAGRAM)
         packet = BytesIO()
         plain_header.encode(packet)
         datagram_header.encode(packet)
         packet.write(data)
         self.network.send(plain_header, packet.getvalue())
         return

     wrapped_header = header_for(Protocol.RELIABLE)
     # NOTE(review): outgoing data is tagged with ReliableFlags.ACK; confirm
     # this matches how the receiving reliable handler interprets that flag.
     reliable_header = ReliableHeader(protocol=Protocol.DATAGRAM,
                                      flags=ReliableFlags.ACK,
                                      sequence=wrapped_header.identity,
                                      acknowledged=[])
     overhead = DatagramHeader.size() + reliable_header.size()
     if len(data) + overhead > self.network.mtu():
         raise RuntimeError(
             "Fragmentation not supported for reliable protocol")
     packet = BytesIO()
     for part in (wrapped_header, reliable_header, datagram_header):
         part.encode(packet)
     packet.write(data)
     self.reliable_protocol.send(wrapped_header, reliable_header,
                                 packet.getvalue())
Ejemplo n.º 8
0
    def handle_l4(self, network_header: NetworkHeader, stream: BytesIO):
        """Process a RELIABLE packet: record an ACK, or ACK and deliver data.

        If the ReliableHeader has the ACK flag set, the acknowledgement is
        passed to ``reliable_manager.handle_ack``. Otherwise the packet is
        ignored unless its sender is an alive neighbor. For a known neighbor,
        an ACK echoing the received sequence number is sent straight back
        (TTL 1). The remaining stream is then dispatched to the handler
        registered for ``reliable_header.protocol``.

        Args:
            network_header: Header of the incoming packet.
            stream: Stream positioned just after the network header.
        """
        reliable_header = ReliableHeader.decode(stream)
        self.debug(f"Handling reliable {reliable_header}")

        if reliable_header.flags & ReliableFlags.ACK:
            # We received an ACK
            self.reliable_manager.handle_ack(network_header.source,
                                             reliable_header.sequence)
        else:
            # An ACK is being requested
            if network_header.source not in self.network.alive_neighbors():
                # Ignore if we haven't seen this neighbor before, we need discovery first
                # TODO send discovery to this neighbor
                self.debug(
                    f"Ignoring {reliable_header} since we don't know this neighbor."
                )
                return
            outgoing_network_header = NetworkHeader(
                version=0,
                qos=QoS.Default,
                protocol=Protocol.RELIABLE,
                ttl=1,
                identity=self.network.next_sequence(),
                length=0,
                source=self.network.our_address,
                destination=network_header.source,
            )
            ack_header = ReliableHeader(
                protocol=Protocol.NONE,
                flags=ReliableFlags.ACK,
                sequence=reliable_header.sequence,
            )
            response = BytesIO()
            outgoing_network_header.encode(response)
            ack_header.encode(response)
            response.seek(0)
            buffer = response.read()
            # Log the ACK's recipient (destination), not our own address.
            self.info(
                f"Sending Ack for sequence {reliable_header.sequence} to address {outgoing_network_header.destination}"
            )
            self.network.send(outgoing_network_header, buffer)
            self.handlers.handle_l4(network_header, reliable_header.protocol,
                                    stream)
Ejemplo n.º 9
0
    def test_ttl_cache(self):
        """TTLCache.contains() misses once per key, then hits until the TTL lapses.

        Equal headers hash the same and share an entry, a header that differs
        only by identity gets its own entry, and every entry is gone once the
        mock clock moves past the 10-second TTL.
        """
        clock = MockTime()
        cache = TTLCache(clock, 10)

        def datagram_header(identity):
            return NetworkHeader(version=0,
                                 protocol=Protocol.DATAGRAM,
                                 qos=QoS.Default,
                                 ttl=4,
                                 identity=identity,
                                 length=0,
                                 source=MeshAddress(1),
                                 destination=MeshAddress(2))

        original = datagram_header(42)
        # First lookup records the key; second one sees it.
        self.assertFalse(cache.contains(hash(original)))
        self.assertTrue(cache.contains(hash(original)))

        duplicate = datagram_header(42)
        self.assertTrue(cache.contains(hash(duplicate)))

        other = datagram_header(43)
        self.assertFalse(cache.contains(hash(other)))
        self.assertTrue(cache.contains(hash(other)))

        clock.sleep(11)
        for expired in (original, other):
            self.assertFalse(cache.contains(hash(expired)))
Ejemplo n.º 10
0
    def send_hello(self):
        """Broadcast a one-hop HELLO with our host name and alive neighbors.

        The HelloHeader is built from the ``host.name`` config value and
        ``alive_neighbors()``. It goes out in a HELLO network header to
        ``BroadcastAddress`` with TTL 1 and lower QoS.
        """
        self.debug("Sending Hello")
        hello = HelloHeader(self.config.get("host.name"),
                            self.alive_neighbors())
        header = NetworkHeader(
            version=0,
            qos=QoS.Lower,
            protocol=Protocol.HELLO,
            ttl=1,
            identity=self.next_sequence(),
            length=0,
            source=self.our_address,
            destination=self.BroadcastAddress,
        )

        packet = BytesIO()
        for part in (header, hello):
            part.encode(packet)
        self.send(header, packet.getvalue())
Ejemplo n.º 11
0
    def test_encode_decode_datagram(self):
        """Round-trip a network + datagram header and payload via encode_packet.

        Both decoded headers must equal the originals, and the bytes left in
        the stream after the headers must be exactly the payload.
        """
        msg = "Hello, World!".encode("utf-8")
        datagram_header_1 = DatagramHeader(source=100,
                                           destination=100,
                                           length=len(msg),
                                           checksum=crc_b(msg))

        header1 = NetworkHeader(version=0,
                                protocol=Protocol.DATAGRAM,
                                qos=QoS.Default,
                                ttl=4,
                                identity=42,
                                length=datagram_header_1.size() + len(msg),
                                source=MeshAddress(1),
                                destination=MeshAddress(2))

        data = encode_packet(header1, [datagram_header_1], msg)
        stream = BytesIO(data)
        header2 = NetworkHeader.decode(stream)
        datagram2 = DatagramHeader.decode(stream)
        self.assertEqual(header1, header2)
        self.assertEqual(datagram_header_1, datagram2)
        self.assertEqual(stream.read(), msg)
Ejemplo n.º 12
0
 def send_ping(self, node: MeshAddress) -> int:
     """Send a PING control packet to ``node`` and record it as pending.

     The ping carries one random byte as its identifier. Before the packet is
     handed to the network, a PingResult is appended to
     ``self.stats[node].results``. It holds that identifier, the send time
     from ``time.time_ns()``, ``None`` as the reply placeholder, and a
     Condition bound to ``self.mutex``.

     Returns:
         The identifier byte carried in the ping.
     """
     ping = ControlHeader(True, ControlType.PING, 1,
                          bytes([secure_random_byte()]))
     header = NetworkHeader(
         version=0,
         qos=QoS.Higher,
         protocol=Protocol.CONTROL,
         ttl=7,
         identity=self.network.next_sequence(),
         length=0,
         source=self.network.our_address,
         destination=node,
     )
     wire = BytesIO()
     header.encode(wire)
     ping.encode(wire)
     packet = wire.getvalue()

     pending_results = self.stats[node].results
     ping_id = ping.extra[0]
     pending_results.append(
         PingResult(ping_id, time.time_ns(), None,
                    threading.Condition(self.mutex)))
     self.network.send(header, packet)
     return ping.extra[0]
    def test_send_receive(self):
        """A DATAGRAM sent by node 1 to node 4 reaches node 4's handler intact."""
        received = None

        class CapturingHandler(L4Handler):
            # Strips the datagram header and keeps the remaining payload.
            def handle_l4(self, network_header: NetworkHeader,
                          stream: BytesIO):
                nonlocal received
                DatagramHeader.decode(stream)
                received = stream.read()

        payload = "Hello, Node 4".encode("utf-8")
        header = NetworkHeader(
            version=0,
            qos=QoS.Lower,
            protocol=Protocol.DATAGRAM,
            ttl=3,
            identity=10,
            length=0,
            source=MeshAddress(1),
            destination=MeshAddress(4),
        )
        datagram = DatagramHeader(source=100,
                                  destination=100,
                                  length=len(payload),
                                  checksum=0)
        packet = BytesIO()
        header.encode(packet)
        datagram.encode(packet)
        packet.write(payload)

        self.node_4.l4_handlers.handlers[
            Protocol.DATAGRAM] = CapturingHandler()
        self.node_1.send(header, packet.getvalue())
        self.for_each_node(self.drain_queue)
        self.assertEqual(received.decode("utf-8"), "Hello, Node 4")
Ejemplo n.º 14
0
 def send_broadcast(self, port: int, data: bytes):
     """Send ``data`` on ``port`` to every alive neighbor as a BROADCAST.

     One BroadcastHeader (our address, ``port``, a fresh sequence number, the
     payload length and its CRC) is shared by every copy. Each neighbor gets
     its own one-hop network header, and the packet is sent through
     ``reliable_manager``. When there are no alive neighbors, nothing is sent
     and no sequence number is consumed.

     Args:
         port: Destination port carried in the broadcast header.
         data: Payload to broadcast.
     """
     neighbors = self.network.alive_neighbors()
     if not neighbors:
         # Actually skip, as the message says, instead of falling through
         # and consuming a sequence number for nothing.
         self.debug("Skipping broadcast since we have no neighbors.")
         return
     broadcast_header = BroadcastHeader(source=self.network.our_address,
                                        port=port,
                                        sequence=self.next_sequence(),
                                        length=len(data),
                                        checksum=crc_b(data))
     # The L4 part of the packet is the same for every neighbor.
     buffer = encode_partial_packet([broadcast_header], data)
     for neighbor in neighbors:
         network_header = NetworkHeader(
             version=0,
             qos=QoS.Default,
             protocol=Protocol.BROADCAST,
             ttl=1,
             identity=self.network.next_sequence(),
             length=0,
             source=self.network.our_address,
             destination=neighbor,
         )
         self.info(f"Broadcasting {broadcast_header} to {neighbor}: {data}")
         self.reliable_manager.send(network_header, buffer)
Ejemplo n.º 15
0
    def check_acks(self):
        """Retry or expire the oldest un-acknowledged reliable send.

        Blocks on the ``not_empty`` condition until ``self.sent`` has at least
        one entry, then examines only the head item and increments its
        attempt counter. Once the counter exceeds ``ReliableManager.MaxResend``,
        the item is removed, ``network.failed_send`` is called for its
        destination and ``not_full.notify()`` is issued. Otherwise the stored
        packet is re-sent with a fresh network-layer identity. The timer is
        reset in both cases.
        """
        with self.not_empty:
            # Wait until there is something awaiting acknowledgement.
            while len(self.sent) == 0:
                self.not_empty.wait()

            # Only the oldest pending entry is handled per call.
            item = self.sent[0]
            item.attempt += 1
            if item.attempt > ReliableManager.MaxResend:
                # Give up on this packet and report the failed destination.
                self.info(f"Expiring {item}")
                self.sent.remove(item)
                self.network.failed_send(item.header.destination)
                # Wake a sender that may be waiting for room in ``sent``.
                self.not_full.notify()
            else:
                self.info(f"Resending {item}")
                # BytesIO copies item.buffer, so the stored bytes are not
                # modified by the rewrite below.
                stream = BytesIO(item.buffer)
                network_header = NetworkHeader.decode(stream)
                # New identity per retransmission; presumably so the
                # receiver's header-hash duplicate cache does not drop the
                # retry -- TODO confirm.
                network_header = dataclasses.replace(
                    network_header, identity=self.network.next_sequence())
                # Overwrite the header bytes in place. This assumes the
                # re-encoded header has the same length as the original.
                stream.seek(0)
                network_header.encode(stream)
                stream.seek(0)
                self.network.send(network_header, stream.read())
                # TODO backoff
            self.timer.reset()
Ejemplo n.º 16
0
    def process_incoming(self, payload: L2Payload):
        """Decode an incoming L3 packet and handle, deliver, or forward it.

        Handling order:
          1. CONTROL packets addressed to us (PING only; other control types
             are logged and ignored).
          2. HELLO, LINK_STATE and LINK_STATE_QUERY, regardless of destination.
          3. Packets already seen (per ``header_cache``) are dropped.
          4. Packets for our address go to the L4 handlers.
          5. Broadcasts go to the L4 handlers and, while TTL > 1, are re-sent
             with TTL-1 on every link except the one they arrived on.
          6. Anything else is forwarded with TTL-1 (excluding the arrival
             link) while TTL > 1, and dropped otherwise.

        Args:
            payload: L2 payload whose ``l3_data`` holds the packet and whose
                ``link_id`` identifies the arrival link.
        """
        stream = BytesIO(payload.l3_data)

        try:
            network_header = NetworkHeader.decode(stream)
        except Exception as e:
            self.error(f"Could not decode network packet from {payload}.", e)
            return

        def forward_with_decremented_ttl():
            # Rewrite the header in place with TTL-1 and re-send the whole
            # packet on every link except the one it arrived on.
            header_copy = dataclasses.replace(network_header,
                                              ttl=network_header.ttl - 1)
            stream.seek(0)
            header_copy.encode(stream)
            stream.seek(0)
            self.send(header_copy,
                      stream.read(),
                      exclude_link_id=payload.link_id)

        # Handle L3 protocols first
        if network_header.destination == self.our_address and network_header.protocol == Protocol.CONTROL:
            ctrl = ControlHeader.decode(stream)
            self.info(f"Got {ctrl} from {network_header.source}")
            if ctrl.control_type == ControlType.PING:
                self.ping_protocol.handle_ping(network_header, ctrl)
            else:
                self.warning(
                    f"Ignoring unsupported control packet: {ctrl.control_type}"
                )
            return

        if network_header.protocol == Protocol.HELLO:
            self.handle_hello(payload.link_id, network_header,
                              HelloHeader.decode(stream))
            return

        if network_header.protocol == Protocol.LINK_STATE:
            self.handle_advertisement(
                payload.link_id, network_header,
                LinkStateAdvertisementHeader.decode(stream))
            return

        if network_header.protocol == Protocol.LINK_STATE_QUERY:
            self.handle_query(payload.link_id, network_header,
                              LinkStateQueryHeader.decode(stream))
            return

        # Now decide if we should handle or drop
        if self.header_cache.contains(hash(network_header)):
            self.debug(f"Dropping duplicate {network_header}")
            return

        # If the packet is addressed to us, handle it
        if network_header.destination == self.our_address:
            self.debug(f"Handling {network_header}")
            self.l4_handlers.handle_l4(network_header, network_header.protocol,
                                       stream)
            return

        if network_header.destination == self.BroadcastAddress:
            self.debug(f"Handling broadcast {network_header}")
            self.l4_handlers.handle_l4(network_header, network_header.protocol,
                                       stream)
            if network_header.ttl > 1:
                # Decrease the TTL and re-broadcast on all links except where we heard it
                forward_with_decremented_ttl()
            else:
                self.debug("Not re-broadcasting due to TTL")
        elif network_header.ttl > 1:
            # Unicast for someone else: apply the same TTL guard as
            # broadcasts, so no header goes out with TTL 0 or below.
            forward_with_decremented_ttl()
        else:
            self.debug(f"Not forwarding {network_header} due to TTL")
Ejemplo n.º 17
0
    def handle_query(self, link_id: int, network_header: NetworkHeader,
                     query: LinkStateQueryHeader):
        """Answer a link-state query with the advertisements the sender needs.

        For every node in ``valid_link_states()`` other than the querier, an
        advertisement from our cache is queued in either of two cases:

        * the query does not mention the node, or
        * ``lollipop_compare(their_epoch, our_epoch) > -1``.

        If the query mentions our own address and the same comparison passes
        against ``our_link_state_epoch``, a freshly generated advertisement of
        our own links is queued too. Each queued advertisement is sent back to
        the querier in its own LINK_STATE packet (TTL 1, lower QoS).

        Args:
            link_id: Link the query arrived on (unused here).
            network_header: Header of the query; its source is the reply
                destination.
            query: Nodes and epochs the querier already knows.
        """
        self.debug(f"Handling {query}")
        dest = network_header.source

        # Node -> epoch reported by the querier. setdefault keeps the first
        # occurrence, matching list.index() semantics, and gives O(1) lookups
        # instead of scanning query.link_nodes for every cached node.
        their_epochs = {}
        for queried_node, queried_epoch in zip(query.link_nodes,
                                               query.link_epochs):
            their_epochs.setdefault(queried_node, queried_epoch)

        def cached_advert(node):
            # Advertisement for ``node`` built from our local cache.
            return LinkStateAdvertisementHeader(
                node=node,
                name=self.host_names.get(node, "unknown"),
                epoch=self.link_state_epochs.get(node),
                link_states=list(self.link_states.get(node)))

        adverts = []
        # Check our local cache of link states
        for node in self.valid_link_states():
            if node == dest:
                continue

            if node not in their_epochs:
                # Sender is missing this node's data
                adverts.append(cached_advert(node))
            # NOTE(review): this sends when the querier's epoch compares >= ours
            # (> -1); confirm lollipop_compare's argument order/direction.
            elif lollipop_compare(their_epochs[node],
                                  self.link_state_epochs[node]) > -1:
                adverts.append(cached_advert(node))

        # Include our latest state if they don't have it.
        # NOTE(review): a query that omits our address entirely gets no
        # advert of our own here -- confirm that is intended.
        if self.our_address in their_epochs:
            if lollipop_compare(their_epochs[self.our_address],
                                self.our_link_state_epoch) > -1:
                link_states = self.generate_our_adverts()
                adverts.append(LinkStateAdvertisementHeader(
                    node=self.our_address,
                    name=self.host_name,
                    epoch=self.our_link_state_epoch,
                    link_states=link_states))

        for advert in adverts:
            self.debug(f"Sending {advert} advert to {network_header.source}")
            resp_header = NetworkHeader(
                version=0,
                qos=QoS.Lower,
                protocol=Protocol.LINK_STATE,
                ttl=1,
                identity=self.next_sequence(),
                length=0,
                source=self.our_address,
                destination=dest,
            )

            stream = BytesIO()
            resp_header.encode(stream)
            advert.encode(stream)
            self.send(resp_header, stream.getvalue())