Ejemplo n.º 1
0
def on_message(client: mqtt.Client, userdata: Any,
               message: mqtt.MQTTMessage) -> None:
    """Processes messages from MQTT and forwards them to netlink.

    The domain is taken from the topic: the word characters following the
    configured ``domain_prefix`` up to the next ``/``. The payload is decoded as
    UTF-8 and used as the peer's public key. A WireGuardClient built from both is
    handed to ``link_handler``.

    Arguments:
        client: the client instance for this callback.
        userdata: the private user data.
        message: The MQTT message.

    Raises:
        ValueError: if no ``domain_prefix`` is configured, or the topic does not
            contain the prefix followed by a domain name.
    """
    # TODO(ruairi): Check bounds and raise exception here.
    logger.debug("Got message %s from MTQQ", message)
    domain_prefix = load_config().get("domain_prefix")
    if not domain_prefix:
        # Without a prefix the pattern would degenerate and capture garbage.
        raise ValueError(
            f"No domain_prefix configured; cannot extract domain from {message.topic}")
    # Escape the prefix so characters such as "." are matched literally.
    match = re.search(rf"/.*{re.escape(domain_prefix)}(\w+)/", message.topic)
    if not match:
        raise ValueError(
            f"Could not find a match for {domain_prefix} on {message.topic}")
    domain = match.group(1)
    logger.debug("Found domain %s", domain)
    # Separate name so the MQTT `client` parameter is not shadowed.
    wg_client = WireGuardClient(
        public_key=message.payload.decode("utf-8"),
        domain=domain,
        remove=False,
    )
    logger.info(
        "Received create message for key %s on domain %s with lladdr %s",
        wg_client.public_key, domain, wg_client.lladdr)
    # TODO(ruairi): Verify return type here.
    logger.debug(link_handler(wg_client))
Ejemplo n.º 2
0
def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None:
    """Handles MQTT connect and subscribes to topics on connect.

    Subscribes to ``wireguard/<domain>/+`` for every entry of the ``domains``
    config key. Nothing is subscribed when the broker refused the connection
    (non-zero ``rc``).

    Arguments:
        client: the client instance for this callback.
        userdata: the private user data.
        flags: The MQTT flags.
        rc: The MQTT rc (0 means the connection was accepted).
    """
    logger.debug("Connected with result code %s", rc)
    if rc != 0:
        logger.error("MQTT connection refused with result code %s", rc)
        return

    # A missing "domains" key yields None; treat it as "no domains".
    domains = load_config().get("domains") or []
    if not domains:
        logger.warning("No domains configured; not subscribing to any topics")

    # Subscribing in on_connect() means that if we lose the connection and
    # reconnect then subscriptions will be renewed.
    for domain in domains:
        topic = f"wireguard/{domain}/+"
        logger.info("Subscribing to topic %s", topic)
        client.subscribe(topic)
Ejemplo n.º 3
0
def wg_key_exchange() -> Tuple[str, int]:
    """Retrieves a new key and publishes it for the gateways.

    Parses the (forced) JSON request body into a KeyExchange and publishes its
    public key to the MQTT topic ``wireguard/<domain>/all``.

    Returns:
        A ``(response, status)`` pair: ``{"Message": "OK"}`` with 200 on
        success, or ``{"error": {"message": ...}}`` with 400 when the body
        cannot be turned into a KeyExchange.
    """
    try:
        data = KeyExchange.from_dict(request.get_json(force=True))
    except TypeError as ex:
        # Return the JSON error directly. abort(400, response) would use the
        # response only as the exception description, so the client would get
        # an HTML error page instead of this JSON body.
        return jsonify({"error": {"message": str(ex)}}), 400

    key = data.public_key
    domain = data.domain
    # in case we want to decide here later we want to publish it only to dedicated gateways
    gateway = "all"
    logger.info("wg_key_exchange: Domain: %s, Key:%s", domain, key)

    mqtt.publish(f"wireguard/{domain}/{gateway}", key)
    return jsonify({"Message": "OK"}), 200
Ejemplo n.º 4
0
def wg_flush_stale_peers(domain: str) -> List[Dict]:
    """Removes stale peers.

    Looks up the stale peer keys on interface ``wg-<domain>``. Each key is
    wrapped as a WireGuardClient with ``remove=True`` and passed to
    ``link_handler``.

    Arguments:
        domain: The domain to detect peers on.

    Returns:
        The ``link_handler`` result for each stale peer.
    """
    logger.info("Searching for stale clients for %s", domain)
    stale_clients = list(find_stale_wireguard_clients("wg-" + domain))
    logger.debug("Found stale clients: %s", stale_clients)
    logger.info("Searching for stale WireGuard clients.")
    stale_wireguard_clients = [
        WireGuardClient(public_key=stale_client, domain=domain, remove=True)
        for stale_client in stale_clients
    ]
    logger.debug("Found stale WireGuard clients: %s", stale_wireguard_clients)
    logger.info("Processing clients.")
    link_handled = [
        link_handler(stale_client) for stale_client in stale_wireguard_clients
    ]
    logger.debug("Handled the following clients: %s", link_handled)
    return link_handled
Ejemplo n.º 5
0
def find_stale_wireguard_clients(wg_interface: str) -> List:
    """Fetches and returns a list of peers which have not had recent handshakes.

    A peer is stale when its last handshake (``tv_sec``) lies more than
    ``_PEER_TIMEOUT_HOURS`` hours in the past. A peer with no recorded
    handshake time counts as handshake time 0, i.e. stale.

    Arguments:
        wg_interface: The WireGuard interface to query.

    Returns:
        The UTF-8 decoded public keys of peers which have not recently seen a
        handshake.
    """
    from datetime import timezone

    # Epoch-second cutoff: handshakes older than this are stale. Use an aware
    # UTC "now" so subtracting the timeout is a true elapsed-time offset. Naive
    # local wall-clock arithmetic is off by an hour across DST changes.
    cutoff_ts = int(
        (datetime.now(tz=timezone.utc)
         - timedelta(hours=_PEER_TIMEOUT_HOURS)).timestamp()
    )
    logger.info(
        "Starting search for stale wireguard peers for interface %s.", wg_interface
    )
    with pyroute2.WireGuard() as wg:
        all_clients = []
        peers_on_interface = wg.info(wg_interface)
        logger.info("Got infos: %s.", peers_on_interface)
        for peer in peers_on_interface:
            clients = peer.get_attr("WGDEVICE_A_PEERS")
            logger.info("Got clients: %s.", clients)
            if clients:
                all_clients.extend(clients)

        stale_keys = []
        for client in all_clients:
            handshake = client.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")
            last_seen = handshake.get("tv_sec", 0) if handshake is not None else 0
            if last_seen < cutoff_ts:
                stale_keys.append(
                    client.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8"))
        return stale_keys
Ejemplo n.º 6
0
def clean_up_worker(domains: List[Text]) -> None:
    """Starts a flush_workers thread for each given domain.

    The configured ``domain_prefix`` is stripped from each domain and the
    remainder is passed to ``flush_workers``. Domains are logged and skipped
    when they do not contain the prefix, or when nothing follows it. All
    domains are skipped when no prefix is configured.

    Arguments:
        domains: list of domains.
    """
    logger.debug("Cleaning up the following domains: %s", domains)
    prefix = config.load_config().get("domain_prefix")
    for domain in domains:
        logger.info("Scheduling cleanup task for %s, ", domain)
        # An empty/None prefix cannot be stripped (str.split/partition reject
        # it), and an empty remainder would start a worker for interface "wg-".
        cleaned_domain = ""
        if prefix:
            _, found, cleaned_domain = domain.partition(prefix)
            if not found:
                cleaned_domain = ""
        if not cleaned_domain:
            logger.error(
                "Cannot strip domain with prefix %s from passed value %s. Skipping cleanup operation",
                prefix,
                domain,
            )
            continue
        thread = threading.Thread(target=flush_workers,
                                  args=(cleaned_domain, ),
                                  name=f"flush-{cleaned_domain}")
        thread.start()
Ejemplo n.º 7
0
def connect() -> None:
    """Connects to the MQTT broker and processes network traffic forever.

    Reads ``broker_url``, ``broker_port`` and ``keepalive`` from the ``mqtt``
    config section and registers ``on_connect`` / ``on_message``. It then blocks
    in ``loop_forever()``. Topic subscriptions are made in ``on_connect``.
    """
    base_config = fetch_from_config("mqtt")
    broker_address = base_config.get("broker_url")
    # Fall back to paho's own connect() defaults rather than passing None.
    broker_port = base_config.get("broker_port", 1883)
    broker_keepalive = base_config.get("keepalive", 60)
    # TODO(ruairi): Move the hostname to a global variable.
    # Pass client_id by keyword: in paho-mqtt 2.x the first positional
    # parameter is callback_api_version.
    client = mqtt.Client(client_id=socket.gethostname())

    # Register handlers
    client.on_connect = on_connect
    client.on_message = on_message
    logger.info("connecting to broker %s", broker_address)

    client.connect(broker_address,
                   port=broker_port,
                   keepalive=broker_keepalive)
    client.loop_forever()
Ejemplo n.º 8
0
def link_handler(client: WireGuardClient) -> Dict:
    """Updates fdb, route and WireGuard peers tables for a given WireGuard peer.

    Exceptions from the WireGuard peer and bridge FDB updates propagate to the
    caller. A failing route update is instead logged (with traceback), and the
    exception object is stored under "Route" so that the FDB step still runs.

    Arguments:
        client: A WireGuard peer to manipulate.
    Returns:
        The outcome of each operation, keyed by "Wireguard", "Route" and
        "Bridge FDB".
    """
    # Log before the first operation so a failure in it is attributable.
    logger.debug("Handling links for %s", client)
    results = {}
    # Updates WireGuard peers.
    results["Wireguard"] = update_wireguard_peer(client)
    try:
        # Updates routes to the WireGuard Peer.
        results["Route"] = route_handler(client)
        logger.info("Updated route for %s", client)
    except Exception as e:
        # TODO(ruairi): re-raise exception here.
        logger.exception("Failed to update route for %s (%s)", client, e)
        results["Route"] = e
    # Updates the bridge forwarding database for the peer.
    results["Bridge FDB"] = bridge_fdb_handler(client)
    logger.debug("Updated Bridge FDB for %s", client)
    return results
Ejemplo n.º 9
0
def flush_workers(domain: Text) -> None:
    """Calls peer flush every _CLEANUP_TIME interval.

    Loops forever. A failing flush is logged with its traceback and retried on
    the next interval, rather than ending the loop.

    Arguments:
        domain: The domain whose stale peers are flushed.
    """
    while True:
        time.sleep(_CLEANUP_TIME)
        logger.info("Running cleanup task for %s", domain)
        try:
            flushed = wg_flush_stale_peers(domain)
        except Exception:
            # Letting this escape would end the loop (and, when run as a thread
            # target, the thread), silently stopping all future cleanups.
            logger.exception("Cleanup task for %s failed", domain)
            continue
        logger.info("Cleaned up domains: %s", flushed)