Example #1
    async def create(cls,
                     node_id: DHTID,
                     bucket_size: int,
                     depth_modulo: int,
                     num_replicas: int,
                     wait_timeout: float,
                     parallel_rpc: Optional[int] = None,
                     cache_size: Optional[int] = None,
                     listen=True,
                     listen_on='0.0.0.0:*',
                     channel_options: Optional[Sequence[Tuple[str,
                                                              Any]]] = None,
                     **kwargs) -> DHTProtocol:
        """
        A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
        As a side-effect, DHTProtocol also maintains a routing table as described in
        https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf

        See DHTNode (node.py) for a more detailed description.

        :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
         Only the call_* methods are meant to be called publicly, e.g. from DHTNode
         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
        """
        self = cls(_initialized_with_create=True)
        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
        self.wait_timeout, self.channel_options = wait_timeout, channel_options
        self.storage, self.cache = LocalStorage(), LocalStorage(
            maxsize=cache_size)
        self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
        self.rpc_semaphore = asyncio.Semaphore(
            parallel_rpc if parallel_rpc is not None else float('inf'))

        if listen:  # set up server to process incoming rpc requests
            grpc.experimental.aio.init_grpc_aio()
            self.server = grpc.experimental.aio.server(**kwargs)
            dht_grpc.add_DHTServicer_to_server(self, self.server)

            found_port = self.server.add_insecure_port(listen_on)
            assert found_port != 0, f"Failed to listen to {listen_on}"
            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(),
                                              rpc_port=found_port)
            self.port = found_port
            await self.server.start()
        else:  # not listening to incoming requests, client-only mode
            # note: use empty node_info so peers won't add you to their routing tables
            self.node_info, self.server, self.port = dht_pb2.NodeInfo(
            ), None, None
            if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
                warn(
                    f"DHTProtocol has no server (due to listen=False), listen_on"
                    f"and kwargs have no effect (unused kwargs: {kwargs})")
        return self
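
A minimal usage sketch for Example #1 (hedged: it assumes the hivemind-style DHTProtocol whose create() coroutine is shown above; the full class, including call_ping and shutdown, appears in Example #17, and the endpoint string and parameter values here are illustrative):

import asyncio

async def demo():
    # spawn two protocol instances via the create() coroutine (the plain constructor is private)
    alice = await DHTProtocol.create(DHTID.generate(), bucket_size=20, depth_modulo=5,
                                     num_replicas=3, wait_timeout=5)
    bob = await DHTProtocol.create(DHTID.generate(), bucket_size=20, depth_modulo=5,
                                   num_replicas=3, wait_timeout=5)
    # only the call_* wrappers are meant to be used directly; call_ping returns the
    # peer's DHTID (or None on timeout) and updates alice's routing table as a side effect
    bob_id = await alice.call_ping(f'127.0.0.1:{bob.port}')
    print(bob_id == bob.node_id)
    await alice.shutdown()
    await bob.shutdown()

asyncio.run(demo())
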
Example #2
 def __init__(self, sourceNode, storage, ksize):
     self.ksize = ksize
     self.router = RoutingTable(self, ksize, sourceNode)
     self.storage = storage
     self.sourceNode = sourceNode
     self.multiplexer = None
     self.log = Logger(system=self)
     self.handled_commands = [
         PING, STUN, STORE, DELETE, FIND_NODE, FIND_VALUE, HOLE_PUNCH
     ]
     RPCProtocol.__init__(self, sourceNode.getProto(), self.router)
Example #3
 def __init__(self, sourceNode, storage, ksize, database, signing_key):
     self.ksize = ksize
     self.router = RoutingTable(self, ksize, sourceNode)
     self.storage = storage
     self.sourceNode = sourceNode
     self.multiplexer = None
     self.db = database
     self.signing_key = signing_key
     self.log = Logger(system=self)
     self.handled_commands = [PING, STUN, STORE, DELETE, FIND_NODE, FIND_VALUE, HOLE_PUNCH, INV, VALUES]
     self.recent_transfers = set()
     RPCProtocol.__init__(self, sourceNode, self.router)
Example #4
def test_routing_table_search():
    for table_size, lower_active, upper_active in [(10, 10, 10),
                                                   (10_000, 800, 1100)]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
        num_added = 0
        total_nodes = 0

        for phony_neighbor_port in random.sample(range(1_000_000), table_size):
            routing_table.add_or_update_node(
                DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
            new_total = sum(
                len(bucket.nodes_to_endpoint)
                for bucket in routing_table.buckets)
            num_added += new_total > total_nodes
            total_nodes = new_total
Example #5
 def setUp(self):
     self.catcher = []
     observer = self.catcher.append
     log.addObserver(observer)
     self.addCleanup(log.removeObserver, observer)
     self.node = Node(digest("test"), "127.0.0.1", 1234)
     self.router = RoutingTable(self, 20, self.node.id)
Example #6
File: appstate.py  Project: tomcur/otdht
    def prepare():
        """
        Prepare the application state.
        """
        thisNodeIP = config.NODE_IP
        thisNodePort = config.NODE_PORT
        thisNodeHash = Hash(sha1(config.NODE_ID_NAME).digest())
        AppState.thisNode = Node(thisNodeHash, (thisNodeIP, thisNodePort))

        AppState.heartbeat = config.HEARTBEAT

        AppState.tokenSecret = utils.randomBits(160)

        AppState.maxPeersPerTorrent = config.MAX_PEERS_PER_TORRENT

        AppState.k = config.K
        AppState.maxNodesPerPucket = config.MAX_NODES_PER_BUCKET

        AppState.routingTable = RoutingTable()

        if config.PEER_STORAGE == 'file':
            AppState.peerStorage = dht.peerstorage.FilePeerStorage(
                config.PEER_STORAGE_DIR)
        elif config.PEER_STORAGE == 'mysql':
            AppState.peerStorage = dht.peerstorage.MySQLPeerStorage()

        # {transactionID: {(RPCQuery, Node, timestamp)}}
        AppState.outstandingQueries = {}
Example #7
 def __init__(self, sourceNode, storage, ksize):
     self.ksize = ksize
     self.router = RoutingTable(self, ksize, sourceNode)
     self.storage = storage
     self.sourceNode = sourceNode
     self.multiplexer = None
     self.log = Logger(system=self)
     self.handled_commands = [PING, STUN, STORE, DELETE, FIND_NODE, FIND_VALUE]
     RPCProtocol.__init__(self, sourceNode.getProto(), self.router)
Example #8
 def __init__(self, sourceNode, storage, ksize, database):
     self.ksize = ksize
     self.router = RoutingTable(self, ksize, sourceNode)
     self.storage = storage
     self.sourceNode = sourceNode
     self.multiplexer = None
     self.db = database
     self.log = Logger(system=self)
     self.handled_commands = [PING, STUN, STORE, DELETE, FIND_NODE, FIND_VALUE, HOLE_PUNCH, INV, VALUES]
     RPCProtocol.__init__(self, sourceNode, self.router)
Example #9
def test_routing_table_parameters():
    for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
        (20, 5, 45, 65),
        (50, 5, 35, 45),
        (20, 10, 650, 800),
        (20, 1, 7, 15),
    ]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id,
                                     bucket_size=bucket_size,
                                     depth_modulo=modulo)
        for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
            routing_table.add_or_update_node(
                DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
        for bucket in routing_table.buckets:
            assert len(bucket.replacement_nodes) == 0 or len(
                bucket.nodes_to_endpoint) <= bucket.size
        assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
            f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}"
        )
Example #10
class RoutingTableTest(unittest.TestCase):
    def setUp(self):
        self.node = Node(digest("test"), "127.0.0.1", 1234)
        self.router = RoutingTable(self, 20, self.node.id)

    def test_addContact(self):
        self.router.addContact(mknode())
        self.assertEqual(len(self.router.buckets), 1)
        self.assertEqual(len(self.router.buckets[0].nodes), 1)

    def test_addDuplicate(self):
        self.router.addContact(self.node)
        self.router.addContact(self.node)
        self.assertEqual(len(self.router.buckets), 1)
        self.assertEqual(len(self.router.buckets[0].nodes), 1)

    def test_addSameIP(self):
        self.router.addContact(self.node)
        self.router.addContact(Node(digest("asdf"), "127.0.0.1", 1234))
        self.assertEqual(len(self.router.buckets), 1)
        self.assertEqual(len(self.router.buckets[0].nodes), 1)
        self.assertTrue(self.router.buckets[0].getNodes()[0].id == digest("asdf"))
Example #11
class RoutingTableTest(unittest.TestCase):
    def setUp(self):
        self.node = Node(digest("test"), "127.0.0.1", 1234)
        self.router = RoutingTable(self, 20, self.node.id)

    def test_addContact(self):
        self.router.addContact(mknode())
        self.assertEqual(len(self.router.buckets), 1)
        self.assertEqual(len(self.router.buckets[0].nodes), 1)

    def test_addDuplicate(self):
        self.router.addContact(self.node)
        self.router.addContact(self.node)
        self.assertEqual(len(self.router.buckets), 1)
        self.assertEqual(len(self.router.buckets[0].nodes), 1)

    def test_addSameIP(self):
        self.router.addContact(self.node)
        self.router.addContact(Node(digest("asdf"), "127.0.0.1", 1234))
        self.assertEqual(len(self.router.buckets), 1)
        self.assertEqual(len(self.router.buckets[0].nodes), 1)
        self.assertTrue(
            self.router.buckets[0].getNodes()[0].id == digest("asdf"))
Example #12
def test_routing_table_basic():
    node_id = DHTID.generate()
    routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
    added_nodes = []

    for phony_neighbor_port in random.sample(range(10000), 100):
        phony_id = DHTID.generate()
        routing_table.add_or_update_node(phony_id,
                                         f'{LOCALHOST}:{phony_neighbor_port}')
        assert phony_id in routing_table
        assert f'{LOCALHOST}:{phony_neighbor_port}' in routing_table
        assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
        assert routing_table[f'{LOCALHOST}:{phony_neighbor_port}'] == phony_id
        added_nodes.append(phony_id)

    assert routing_table.buckets[
        0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
    for bucket in routing_table.buckets:
        assert len(
            bucket.replacement_nodes
        ) == 0, "There should be no replacement nodes in a table with 100 entries"
    assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)

    random_node = random.choice(added_nodes)
    assert routing_table.get(node_id=random_node) == routing_table[random_node]
    dummy_node = DHTID.generate()
    assert (dummy_node
            not in routing_table) == (routing_table.get(node_id=dummy_node) is
                                      None)

    for node in added_nodes:
        found_bucket_index = routing_table.get_bucket_index(node)
        for bucket_index, bucket in enumerate(routing_table.buckets):
            if bucket.lower <= node < bucket.upper:
                break
        else:
            raise ValueError(
                "Naive search could not find bucket. Universe has gone crazy.")
        assert bucket_index == found_bucket_index
Example #13
class KademliaProtocol(RPCProtocol):
    implements(MessageProcessor)

    def __init__(self, sourceNode, storage, ksize, database, signing_key):
        self.ksize = ksize
        self.router = RoutingTable(self, ksize, sourceNode)
        self.storage = storage
        self.sourceNode = sourceNode
        self.multiplexer = None
        self.db = database
        self.signing_key = signing_key
        self.log = Logger(system=self)
        self.handled_commands = [
            PING, STUN, STORE, DELETE, FIND_NODE, FIND_VALUE, HOLE_PUNCH, INV,
            VALUES
        ]
        RPCProtocol.__init__(self, sourceNode, self.router)

    def connect_multiplexer(self, multiplexer):
        self.multiplexer = multiplexer

    def getRefreshIDs(self):
        """
        Get ids to search for to keep old buckets up to date.
        """
        ids = []
        for bucket in self.router.getLonelyBuckets():
            ids.append(random.randint(*bucket.range))
        return ids

    def rpc_stun(self, sender):
        self.addToRouter(sender)
        return [sender.ip, str(sender.port)]

    def rpc_ping(self, sender):
        self.addToRouter(sender)
        return [self.sourceNode.getProto().SerializeToString()]

    def rpc_store(self, sender, keyword, key, value, ttl):
        self.addToRouter(sender)
        self.log.debug("got a store request from %s, storing value" %
                       str(sender))
        if len(keyword) == 20 and len(key) <= 33 and len(
                value) <= 2100 and int(ttl) <= 604800:
            self.storage[keyword] = (key, value, int(ttl))
            return ["True"]
        else:
            return ["False"]

    def rpc_delete(self, sender, keyword, key, signature):
        self.addToRouter(sender)
        value = self.storage.getSpecific(keyword, key)
        if value is not None:
            # Try to delete a message from the dht
            if keyword == digest(sender.id):
                try:
                    verify_key = nacl.signing.VerifyKey(sender.pubkey)
                    verify_key.verify(key, signature)
                    self.storage.delete(keyword, key)
                    return ["True"]
                except Exception:
                    return ["False"]
            # Or try to delete a pointer
            else:
                try:
                    node = objects.Node()
                    node.ParseFromString(value)
                    pubkey = node.publicKey
                    try:
                        verify_key = nacl.signing.VerifyKey(pubkey)
                        verify_key.verify(key, signature)
                        self.storage.delete(keyword, key)
                        return ["True"]
                    except Exception:
                        return ["False"]
                except Exception:
                    pass
        return ["False"]

    def rpc_find_node(self, sender, key):
        self.log.debug("finding neighbors of %s in local table" %
                       key.encode('hex'))
        self.addToRouter(sender)
        node = Node(key)
        nodeList = self.router.findNeighbors(node, exclude=sender)
        ret = []
        if self.sourceNode.id == key:
            ret.append(self.sourceNode.getProto().SerializeToString())
        for n in nodeList:
            ret.append(n.getProto().SerializeToString())
        return ret

    def rpc_find_value(self, sender, keyword):
        self.addToRouter(sender)
        ret = ["value"]
        value = self.storage.get(keyword, None)
        if value is None:
            return self.rpc_find_node(sender, keyword)
        ret.extend(value)
        return ret

    def rpc_inv(self, sender, *serlialized_invs):
        self.addToRouter(sender)
        ret = []
        for inv in serlialized_invs:
            try:
                i = objects.Inv()
                i.ParseFromString(inv)
                if self.storage.getSpecific(i.keyword, i.valueKey) is None:
                    ret.append(inv)
            except Exception:
                pass
        return ret

    def rpc_values(self, sender, *serialized_values):
        self.addToRouter(sender)
        for val in serialized_values:
            try:
                v = objects.Value()
                v.ParseFromString(val)
                self.storage[v.keyword] = (v.valueKey, v.serializedData,
                                           int(v.ttl))
            except Exception:
                pass
        return ["True"]

    def callFindNode(self, nodeToAsk, nodeToFind):
        d = self.find_node(nodeToAsk, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callFindValue(self, nodeToAsk, nodeToFind):
        d = self.find_value(nodeToAsk, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callPing(self, nodeToAsk):
        d = self.ping(nodeToAsk)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callStore(self, nodeToAsk, keyword, key, value, ttl):
        d = self.store(nodeToAsk, keyword, key, value, str(int(round(ttl))))
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callDelete(self, nodeToAsk, keyword, key, signature):
        d = self.delete(nodeToAsk, keyword, key, signature)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callInv(self, nodeToAsk, serlialized_inv_list):
        d = self.inv(nodeToAsk, *serlialized_inv_list)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callValues(self, nodeToAsk, serlialized_values_list):
        d = self.values(nodeToAsk, *serlialized_values_list)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def transferKeyValues(self, node):
        """
        Given a new node, send it all the keys/values it should be storing.

        @param node: A new node that just joined (or that we just found out
        about).

        Process:
        For each key in storage, get k closest nodes.  If newnode is closer
        than the furthest in that list, and the node for this server
        is closer than the closest in that list, then store the key/value
        on the new node (per section 2.5 of the paper)
        """
        def send_values(inv_list):
            values = []
            if inv_list[0]:
                for requested_inv in inv_list[1]:
                    try:
                        i = objects.Inv()
                        i.ParseFromString(requested_inv)
                        value = self.storage.getSpecific(i.keyword, i.valueKey)
                        if value is not None:
                            v = objects.Value()
                            v.keyword = i.keyword
                            v.valueKey = i.valueKey
                            v.serializedData = value
                            v.ttl = int(
                                round(
                                    self.storage.get_ttl(
                                        i.keyword, i.valueKey)))
                            values.append(v.SerializeToString())
                    except Exception:
                        pass
                if len(values) > 0:
                    self.callValues(node, values)

        inv = []
        for keyword in self.storage.iterkeys():
            keynode = Node(keyword)
            neighbors = self.router.findNeighbors(keynode, exclude=node)
            if len(neighbors) > 0:
                newNodeClose = node.distanceTo(
                    keynode) < neighbors[-1].distanceTo(keynode)
                thisNodeClosest = self.sourceNode.distanceTo(
                    keynode) < neighbors[0].distanceTo(keynode)
            if len(neighbors) == 0 \
                    or (newNodeClose and thisNodeClosest) \
                    or (thisNodeClosest and len(neighbors) < self.ksize):
                # pylint: disable=W0612
                for k, v in self.storage.iteritems(keyword):
                    i = objects.Inv()
                    i.keyword = keyword
                    i.valueKey = k
                    inv.append(i.SerializeToString())
        if len(inv) > 0:
            self.callInv(node, inv).addCallback(send_values)

    def handleCallResponse(self, result, node):
        """
        If we get a response, add the node to the routing table.  If
        we get no response, make sure it's removed from the routing table.
        """
        if result[0]:
            if self.isNewConnection(node):
                self.log.debug(
                    "call response from new node, transferring key/values")
                reactor.callLater(1, self.transferKeyValues, node)
            self.router.addContact(node)
        else:
            self.log.debug("no response from %s, removing from router" % node)
            self.router.removeContact(node)
        return result

    def addToRouter(self, node):
        """
        Called by rpc_ functions when a node sends them a request.
        We add the node to our router and transfer our stored values
        if they are new and within our neighborhood.
        """
        if self.isNewConnection(node):
            self.log.debug("found a new node, transferring key/values")
            reactor.callLater(1, self.transferKeyValues, node)
        self.router.addContact(node)

    def isNewConnection(self, node):
        if (node.ip, node.port) in self.multiplexer:
            return self.multiplexer[(
                node.ip, node.port)].handler.check_new_connection()
        else:
            return False

    def __iter__(self):
        return iter(self.handled_commands)
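
A hedged sketch of how the Deferred-based call wrappers in Example #13 are typically consumed; protocol, peer, and digest are assumed to come from the surrounding project (an initialized KademliaProtocol with a connected multiplexer, a reachable Node, and the 20-byte SHA-1 digest helper used elsewhere on this page), and the payloads are illustrative:

def on_store(result):
    # handleCallResponse has already updated the router by the time this fires:
    # result[0] is the "peer responded" flag, result[1] is rpc_store's reply (["True"] on success)
    if result[0] and result[1] == ["True"]:
        print("value stored on peer")
    else:
        print("store failed or peer did not respond")

protocol.callStore(peer, digest("some keyword"), "value-key", "payload", 3600).addCallback(on_store)
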
Example #14
 def __init__(self, sourceID, ksize=20):
     self.router = RoutingTable(self, ksize, Node(sourceID))
     self.storage = {}
     self.sourceID = sourceID
Example #15
class FakeProtocol(object):
    def __init__(self, sourceID, ksize=20):
        self.router = RoutingTable(self, ksize, Node(sourceID))
        self.storage = {}
        self.sourceID = sourceID

    def getRefreshIDs(self):
        """
        Get ids to search for to keep old buckets up to date.
        """
        ids = []
        for bucket in self.router.getLonelyBuckets():
            ids.append(random.randint(*bucket.range))
        return ids

    def rpc_ping(self, sender, nodeid):
        source = Node(nodeid, sender[0], sender[1])
        self.router.addContact(source)
        return self.sourceID

    def rpc_store(self, sender, nodeid, key, value):
        source = Node(nodeid, sender[0], sender[1])
        self.router.addContact(source)
        self.log.debug("got a store request from %s, storing value" %
                       str(sender))
        self.storage[key] = value

    def rpc_find_node(self, sender, nodeid, key):
        self.log.info("finding neighbors of %i in local table" %
                      long(nodeid.encode('hex'), 16))
        source = Node(nodeid, sender[0], sender[1])
        self.router.addContact(source)
        node = Node(key)
        return map(tuple, self.router.findNeighbors(node, exclude=source))

    def rpc_find_value(self, sender, nodeid, key):
        source = Node(nodeid, sender[0], sender[1])
        self.router.addContact(source)
        value = self.storage.get(key, None)
        if value is None:
            return self.rpc_find_node(sender, nodeid, key)
        return {'value': value}

    def callFindNode(self, nodeToAsk, nodeToFind):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.find_node(address, self.sourceID, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callFindValue(self, nodeToAsk, nodeToFind):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.find_value(address, self.sourceID, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callPing(self, nodeToAsk):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.ping(address, self.sourceID)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callStore(self, nodeToAsk, key, value):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.store(address, self.sourceID, key, value)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def handleCallResponse(self, result, node):
        """
        If we get a response, add the node to the routing table.  If
        we get no response, make sure it's removed from the routing table.
        """
        if result[0]:
            self.log.info("got response from %s, adding to router" % node)
            self.router.addContact(node)
        else:
            self.log.debug("no response from %s, removing from router" % node)
            self.router.removeContact(node)
        return result
Example #16
class KademliaProtocol(RPCProtocol):
    implements(MessageProcessor)

    def __init__(self, sourceNode, storage, ksize):
        self.ksize = ksize
        self.router = RoutingTable(self, ksize, sourceNode)
        self.storage = storage
        self.sourceNode = sourceNode
        self.multiplexer = None
        self.log = Logger(system=self)
        self.handled_commands = [
            PING, STUN, STORE, DELETE, FIND_NODE, FIND_VALUE, HOLE_PUNCH
        ]
        RPCProtocol.__init__(self, sourceNode.getProto(), self.router)

    def connect_multiplexer(self, multiplexer):
        self.multiplexer = multiplexer

    def getRefreshIDs(self):
        """
        Get ids to search for to keep old buckets up to date.
        """
        ids = []
        for bucket in self.router.getLonelyBuckets():
            ids.append(random.randint(*bucket.range))
        return ids

    def rpc_stun(self, sender):
        self.addToRouter(sender)
        return [sender.ip, str(sender.port)]

    def rpc_ping(self, sender):
        self.addToRouter(sender)
        return [self.sourceNode.getProto().SerializeToString()]

    def rpc_store(self, sender, keyword, key, value):
        self.addToRouter(sender)
        self.log.debug("got a store request from %s, storing value" %
                       str(sender))
        if len(keyword) == 20 and len(key) <= 33 and len(value) <= 1800:
            self.storage[keyword] = (key, value)
            return ["True"]
        else:
            return ["False"]

    def rpc_delete(self, sender, keyword, key, signature):
        self.addToRouter(sender)
        value = self.storage.getSpecific(keyword, key)
        if value is not None:
            # Try to delete a message from the dht
            if keyword == digest(sender.id):
                try:
                    verify_key = nacl.signing.VerifyKey(
                        sender.signed_pubkey[64:])
                    verify_key.verify(key, signature)
                    self.storage.delete(keyword, key)
                    return ["True"]
                except Exception:
                    return ["False"]
            # Or try to delete a pointer
            else:
                try:
                    node = objects.Node()
                    node.ParseFromString(value)
                    pubkey = node.signedPublicKey[64:]
                    try:
                        verify_key = nacl.signing.VerifyKey(pubkey)
                        verify_key.verify(signature + key)
                        self.storage.delete(keyword, key)
                        return ["True"]
                    except Exception:
                        return ["False"]
                except Exception:
                    pass
        return ["False"]

    def rpc_find_node(self, sender, key):
        self.log.info("finding neighbors of %s in local table" %
                      key.encode('hex'))
        self.addToRouter(sender)
        node = Node(key)
        nodeList = self.router.findNeighbors(node, exclude=sender)
        ret = []
        for n in nodeList:
            ret.append(n.getProto().SerializeToString())
        return ret

    def rpc_find_value(self, sender, key):
        self.addToRouter(sender)
        ret = ["value"]
        value = self.storage.get(key, None)
        if value is None:
            return self.rpc_find_node(sender, key)
        ret.extend(value)
        return ret

    def callFindNode(self, nodeToAsk, nodeToFind):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.find_node(address, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callFindValue(self, nodeToAsk, nodeToFind):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.find_value(address, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callPing(self, nodeToAsk):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.ping(address)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callStore(self, nodeToAsk, keyword, key, value):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.store(address, keyword, key, value)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callDelete(self, nodeToAsk, keyword, key, signature):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.delete(address, keyword, key, signature)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def transferKeyValues(self, node):
        """
        Given a new node, send it all the keys/values it should be storing.

        @param node: A new node that just joined (or that we just found out
        about).

        Process:
        For each key in storage, get k closest nodes.  If newnode is closer
        than the furthest in that list, and the node for this server
        is closer than the closest in that list, then store the key/value
        on the new node (per section 2.5 of the paper)
        """
        ds = []
        for keyword in self.storage.iterkeys():
            keynode = Node(keyword)
            neighbors = self.router.findNeighbors(keynode, exclude=node)
            if len(neighbors) > 0:
                newNodeClose = node.distanceTo(
                    keynode) < neighbors[-1].distanceTo(keynode)
                thisNodeClosest = self.sourceNode.distanceTo(
                    keynode) < neighbors[0].distanceTo(keynode)
            if len(neighbors) == 0 \
                    or (newNodeClose and thisNodeClosest) \
                    or (thisNodeClosest and len(neighbors) < self.ksize):
                for k, v in self.storage.iteritems(keyword):
                    ds.append(self.callStore(node, keyword, k, v))
        return defer.gatherResults(ds)

    def handleCallResponse(self, result, node):
        """
        If we get a response, add the node to the routing table.  If
        we get no response, make sure it's removed from the routing table.
        """
        if result[0]:
            if self.router.isNewNode(node):
                self.transferKeyValues(node)
            self.log.info("got response from %s, adding to router" % node)
            self.router.addContact(node)
        else:
            self.log.debug("no response from %s, removing from router" % node)
            self.router.removeContact(node)
        return result

    def addToRouter(self, node):
        """
        Called by rpc_ functions when a node sends them a request.
        We add the node to our router and transfer our stored values
        if they are new and within our neighborhood.
        """
        if self.router.isNewNode(node):
            self.log.debug("Found a new node, transferring key/values")
            self.transferKeyValues(node)
        self.router.addContact(node)

    def __iter__(self):
        return iter(self.handled_commands)
Example #17
class DHTProtocol(dht_grpc.DHTServicer):
    # fmt:off
    node_id: DHTID
    port: int
    bucket_size: int
    num_replicas: int
    wait_timeout: float
    node_info: dht_pb2.NodeInfo
    channel_options: Optional[Sequence[Tuple[str, Any]]]
    server: grpc.experimental.aio.Server
    storage: LocalStorage
    cache: LocalStorage
    routing_table: RoutingTable
    rpc_semaphore: asyncio.Semaphore
    # fmt:on

    @classmethod
    async def create(cls,
                     node_id: DHTID,
                     bucket_size: int,
                     depth_modulo: int,
                     num_replicas: int,
                     wait_timeout: float,
                     parallel_rpc: Optional[int] = None,
                     cache_size: Optional[int] = None,
                     listen=True,
                     listen_on='0.0.0.0:*',
                     channel_options: Optional[Sequence[Tuple[str,
                                                              Any]]] = None,
                     **kwargs) -> DHTProtocol:
        """
        A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
        As a side-effect, DHTProtocol also maintains a routing table as described in
        https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf

        See DHTNode (node.py) for a more detailed description.

        :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
         Only the call_* methods are meant to be called publicly, e.g. from DHTNode
         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
        """
        self = cls(_initialized_with_create=True)
        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
        self.wait_timeout, self.channel_options = wait_timeout, channel_options
        self.storage, self.cache = LocalStorage(), LocalStorage(
            maxsize=cache_size)
        self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
        self.rpc_semaphore = asyncio.Semaphore(
            parallel_rpc if parallel_rpc is not None else float('inf'))

        if listen:  # set up server to process incoming rpc requests
            grpc.experimental.aio.init_grpc_aio()
            self.server = grpc.experimental.aio.server(**kwargs)
            dht_grpc.add_DHTServicer_to_server(self, self.server)

            found_port = self.server.add_insecure_port(listen_on)
            assert found_port != 0, f"Failed to listen to {listen_on}"
            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(),
                                              rpc_port=found_port)
            self.port = found_port
            await self.server.start()
        else:  # not listening to incoming requests, client-only mode
            # note: use empty node_info so peers won't add you to their routing tables
            self.node_info, self.server, self.port = dht_pb2.NodeInfo(
            ), None, None
            if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
                warn(
                    f"DHTProtocol has no server (due to listen=False), listen_on"
                    f"and kwargs have no effect (unused kwargs: {kwargs})")
        return self

    def __init__(self, *, _initialized_with_create=False):
        """ Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances """
        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
        super().__init__()

    async def shutdown(self, timeout=None):
        """ Process existing requests, close all connections and stop the server """
        if self.server:
            await self.server.stop(timeout)
        else:
            warn(
                "DHTProtocol has no server (due to listen=False), it doesn't need to be shut down"
            )

    def _get(self, peer: Endpoint) -> dht_grpc.DHTStub:
        """ get a DHTStub that sends requests to a given peer """
        channel = grpc.experimental.aio.insecure_channel(
            peer, options=self.channel_options)
        return dht_grpc.DHTStub(channel)

    async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
        """
        Get peer's node id and add him to the routing table. If peer doesn't respond, return None
        :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
        :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table

        :return: node's DHTID, if peer responded and decided to send his node_id
        """
        try:
            async with self.rpc_semaphore:
                peer_info = await self._get(peer).rpc_ping(
                    self.node_info, timeout=self.wait_timeout)
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(
                f"DHTProtocol failed to ping {peer}: {error.code()}")
            peer_info = None
        responded = bool(peer_info and peer_info.node_id)
        peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
        asyncio.create_task(
            self.update_routing_table(peer_id, peer, responded=responded))
        return peer_id

    async def rpc_ping(self, peer_info: dht_pb2.NodeInfo,
                       context: grpc.ServicerContext):
        """ Some node wants us to add it to our routing table. """
        if peer_info.node_id and peer_info.rpc_port:
            sender_id = DHTID.from_bytes(peer_info.node_id)
            rpc_endpoint = replace_port(context.peer(),
                                        new_port=peer_info.rpc_port)
            asyncio.create_task(
                self.update_routing_table(sender_id, rpc_endpoint))
        return self.node_info

    async def call_store(
        self,
        peer: Endpoint,
        keys: Sequence[DHTID],
        values: Sequence[BinaryDHTValue],
        expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
        in_cache: Optional[Union[bool,
                                 Sequence[bool]]] = None) -> Sequence[bool]:
        """
        Ask a recipient to store several (key, value : expiration_time) items or update their older value

        :param peer: request this peer to store the data
        :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
        :param values: a list of N serialized values (bytes) for each respective key
        :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
        :param in_cache: a list of booleans, True = store i-th key in cache, False = store i-th key locally
        :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
         until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size

        :return: list of [True / False] True = stored, False = failed (found newer value or no response)
         if peer did not respond (e.g. due to timeout or congestion), returns False for every key
        """
        if isinstance(expiration_time, DHTExpiration):
            expiration_time = [expiration_time] * len(keys)
        in_cache = in_cache if in_cache is not None else [False] * len(
            keys)  # default value (None)
        in_cache = [in_cache] * len(keys) if isinstance(
            in_cache, bool) else in_cache  # single bool
        keys, values, expiration_time, in_cache = map(
            list, [keys, values, expiration_time, in_cache])
        assert len(keys) == len(values) == len(expiration_time) == len(
            in_cache), "Data is not aligned"
        store_request = dht_pb2.StoreRequest(keys=list(
            map(DHTID.to_bytes, keys)),
                                             values=values,
                                             expiration_time=expiration_time,
                                             in_cache=in_cache,
                                             peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get(peer).rpc_store(
                    store_request, timeout=self.wait_timeout)
            if response.peer and response.peer.node_id:
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(
                    self.update_routing_table(peer_id, peer, responded=True))
            return response.store_ok
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(
                f"DHTProtocol failed to store at {peer}: {error.code()}")
            asyncio.create_task(
                self.update_routing_table(
                    self.routing_table.get(endpoint=peer),
                    peer,
                    responded=False))
            return [False] * len(keys)

    async def rpc_store(
            self, request: dht_pb2.StoreRequest,
            context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
        """ Some node wants us to store this (key, value) pair """
        if request.peer:  # if requested, add peer to the routing table
            asyncio.create_task(self.rpc_ping(request.peer, context))
        assert len(request.keys) == len(request.values) == len(
            request.expiration_time) == len(request.in_cache)
        response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
        for key_bytes, value_bytes, expiration_time, in_cache in zip(
                request.keys, request.values, request.expiration_time,
                request.in_cache):
            local_memory = self.cache if in_cache else self.storage
            response.store_ok.append(
                local_memory.store(DHTID.from_bytes(key_bytes), value_bytes,
                                   expiration_time))
        return response

    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
            Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
        """
        Request keys from a peer. For each key, look for its (value, expiration time) locally and
         k additional peers that are most likely to have this key (ranked by XOR distance)

        :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
         value: value stored by the recipient with that key, or None if peer doesn't have this value
         expiration time: expiration time of the returned value, None if no value was found
         neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
         If peer didn't respond, returns None
        """
        keys = list(keys)
        find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes,
                                                         keys)),
                                           peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get(peer).rpc_find(
                    find_request, timeout=self.wait_timeout)
            if response.peer and response.peer.node_id:
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(
                    self.update_routing_table(peer_id, peer, responded=True))
            assert len(response.values) == len(response.expiration_time) == len(response.nearest) == len(keys), \
                "DHTProtocol: response is not aligned with keys and/or expiration times"

            output = {}  # unpack data without special NOT_FOUND_* values
            for key, value, expiration_time, nearest in zip(
                    keys, response.values, response.expiration_time,
                    response.nearest):
                value = value if value != _NOT_FOUND_VALUE else None
                expiration_time = expiration_time if expiration_time != _NOT_FOUND_EXPIRATION else None
                nearest = dict(
                    zip(map(DHTID.from_bytes, nearest.node_ids),
                        nearest.endpoints))
                output[key] = (value, expiration_time, nearest)
            return output
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(
                f"DHTProtocol failed to find at {peer}: {error.code()}")
            asyncio.create_task(
                self.update_routing_table(
                    self.routing_table.get(endpoint=peer),
                    peer,
                    responded=False))

    async def rpc_find(self, request: dht_pb2.FindRequest,
                       context: grpc.ServicerContext) -> dht_pb2.FindResponse:
        """
        Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
        Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
        """
        if request.peer:  # if requested, add peer to the routing table
            asyncio.create_task(self.rpc_ping(request.peer, context))

        response = dht_pb2.FindResponse(values=[],
                                        expiration_time=[],
                                        nearest=[],
                                        peer=self.node_info)
        for key_id in map(DHTID.from_bytes, request.keys):
            maybe_value, maybe_expiration_time = self.storage.get(key_id)
            cached_value, cached_expiration_time = self.cache.get(key_id)
            if (cached_expiration_time or
                    -float('inf')) > (maybe_expiration_time or -float('inf')):
                maybe_value, maybe_expiration_time = cached_value, cached_expiration_time

            nearest_neighbors = self.routing_table.get_nearest_neighbors(
                key_id,
                k=self.bucket_size,
                exclude=DHTID.from_bytes(request.peer.node_id))
            if nearest_neighbors:
                peer_ids, endpoints = zip(*nearest_neighbors)
            else:
                peer_ids, endpoints = [], []

            response.values.append(
                maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
            response.expiration_time.append(
                maybe_expiration_time
                if maybe_expiration_time else _NOT_FOUND_EXPIRATION)
            response.nearest.append(
                dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)),
                              endpoints=endpoints))
        return response

    async def update_routing_table(self,
                                   node_id: Optional[DHTID],
                                   peer_endpoint: Endpoint,
                                   responded=True):
        """
        This method is called on every incoming AND outgoing request to update the routing table

        :param peer_endpoint: sender endpoint for incoming requests, recipient endpoint for outgoing requests
        :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
        :param responded: for outgoing requests, this indicates whether the recipient responded or not.
          For incoming requests, this should always be True
        """
        node_id = node_id if node_id is not None else self.routing_table.get(
            endpoint=peer_endpoint)
        if responded:  # incoming request or outgoing request with response
            if node_id not in self.routing_table:
                # we just met a new node, maybe we know some values that it *should* store
                data_to_send: List[Tuple[DHTID, BinaryDHTValue,
                                         DHTExpiration]] = []
                for key, value, expiration_time in list(self.storage.items()):
                    neighbors = self.routing_table.get_nearest_neighbors(
                        key, self.num_replicas, exclude=self.node_id)
                    if neighbors:
                        nearest_distance = neighbors[0][0].xor_distance(key)
                        farthest_distance = neighbors[-1][0].xor_distance(key)
                        new_node_should_store = node_id.xor_distance(
                            key) < farthest_distance
                        this_node_is_responsible = self.node_id.xor_distance(
                            key) < nearest_distance
                    if not neighbors or (new_node_should_store
                                         and this_node_is_responsible):
                        data_to_send.append((key, value, expiration_time))
                if data_to_send:
                    asyncio.create_task(
                        self.call_store(peer_endpoint,
                                        *zip(*data_to_send),
                                        in_cache=False))

            maybe_node_to_ping = self.routing_table.add_or_update_node(
                node_id, peer_endpoint)
            if maybe_node_to_ping is not None:
                # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
                asyncio.create_task(self.call_ping(maybe_node_to_ping[1])
                                    )  # [1]-th element is that node's endpoint

        else:  # we sent outgoing request and peer did not respond
            if node_id is not None and node_id in self.routing_table:
                del self.routing_table[node_id]
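
A hedged sketch of a store/find round trip through the call_* API of Example #17; proto and peer are assumed to be an existing DHTProtocol instance and a reachable endpoint, and get_dht_time (referenced in the call_store docstring) is imported under an assumed path:

from hivemind.utils import get_dht_time  # import path is an assumption

async def store_then_find(proto: DHTProtocol, peer: Endpoint):
    key = DHTID.generate(source='some_dict_key')
    # call_store broadcasts a single expiration_time across all keys and returns one bool per key
    store_ok = await proto.call_store(peer, keys=[key], values=[b'some bytes'],
                                      expiration_time=get_dht_time() + 60)
    # call_find returns {key: (value, expiration_time, nearest_neighbors)}, or None if the peer is silent
    found = await proto.call_find(peer, [key])
    if found is None:
        return store_ok, None, {}
    value, expiration_time, nearest = found[key]
    return store_ok, value, nearest
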
Example #18
class KademliaProtocol(RPCProtocol):
    implements(MessageProcessor)

    def __init__(self, sourceNode, storage, ksize, database, signing_key):
        self.ksize = ksize
        self.router = RoutingTable(self, ksize, sourceNode)
        self.storage = storage
        self.sourceNode = sourceNode
        self.multiplexer = None
        self.db = database
        self.signing_key = signing_key
        self.log = Logger(system=self)
        self.handled_commands = [PING, STUN, STORE, DELETE, FIND_NODE, FIND_VALUE, HOLE_PUNCH, INV, VALUES]
        self.recent_transfers = set()
        RPCProtocol.__init__(self, sourceNode, self.router)

    def connect_multiplexer(self, multiplexer):
        self.multiplexer = multiplexer

    def getRefreshIDs(self):
        """
        Get ids to search for to keep old buckets up to date.
        """
        ids = []
        for bucket in self.router.getLonelyBuckets():
            ids.append(random.randint(*bucket.range))
        return ids

    def rpc_stun(self, sender):
        self.addToRouter(sender)
        return [sender.ip, str(sender.port)]

    def rpc_ping(self, sender):
        self.addToRouter(sender)
        return [self.sourceNode.getProto().SerializeToString()]

    def rpc_store(self, sender, keyword, key, value, ttl):
        self.addToRouter(sender)
        self.log.debug("got a store request from %s, storing value" % str(sender))
        if len(keyword) == 20 and len(key) <= 33 and len(value) <= 2100 and int(ttl) <= 604800:
            self.storage[keyword] = (key, value, int(ttl))
            return ["True"]
        else:
            return ["False"]

    def rpc_delete(self, sender, keyword, key, signature):
        self.addToRouter(sender)
        value = self.storage.getSpecific(keyword, key)
        if value is not None:
            # Try to delete a message from the dht
            if keyword == digest(sender.id):
                try:
                    verify_key = nacl.signing.VerifyKey(sender.pubkey)
                    verify_key.verify(key, signature)
                    self.storage.delete(keyword, key)
                    return ["True"]
                except Exception:
                    return ["False"]
            # Or try to delete a pointer
            else:
                try:
                    node = objects.Node()
                    node.ParseFromString(value)
                    pubkey = node.publicKey
                    try:
                        verify_key = nacl.signing.VerifyKey(pubkey)
                        verify_key.verify(key, signature)
                        self.storage.delete(keyword, key)
                        return ["True"]
                    except Exception:
                        return ["False"]
                except Exception:
                    pass
        return ["False"]

    def rpc_find_node(self, sender, key):
        self.log.debug("finding neighbors of %s in local table" % key.encode('hex'))
        self.addToRouter(sender)
        node = Node(key)
        nodeList = self.router.findNeighbors(node, exclude=sender)
        ret = []
        if self.sourceNode.id == key:
            ret.append(self.sourceNode.getProto().SerializeToString())
        for n in nodeList:
            ret.append(n.getProto().SerializeToString())
        return ret

    def rpc_find_value(self, sender, keyword):
        self.addToRouter(sender)
        ret = ["value"]
        value = self.storage.get(keyword, None)
        if value is None:
            return self.rpc_find_node(sender, keyword)
        ret.extend(value)
        return ret

    def rpc_inv(self, sender, *serlialized_invs):
        self.addToRouter(sender)
        ret = []
        for inv in serlialized_invs:
            try:
                i = objects.Inv()
                i.ParseFromString(inv)
                if self.storage.getSpecific(i.keyword, i.valueKey) is None:
                    ret.append(inv)
            except Exception:
                pass
        return ret

    def rpc_values(self, sender, *serialized_values):
        self.addToRouter(sender)
        for val in serialized_values[:100]:
            try:
                v = objects.Value()
                v.ParseFromString(val)
                self.storage[v.keyword] = (v.valueKey, v.serializedData, int(v.ttl))
            except Exception:
                pass
        return ["True"]

    def callFindNode(self, nodeToAsk, nodeToFind):
        d = self.find_node(nodeToAsk, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callFindValue(self, nodeToAsk, nodeToFind):
        d = self.find_value(nodeToAsk, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callPing(self, nodeToAsk):
        d = self.ping(nodeToAsk)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callStore(self, nodeToAsk, keyword, key, value, ttl):
        d = self.store(nodeToAsk, keyword, key, value, str(int(round(ttl))))
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callDelete(self, nodeToAsk, keyword, key, signature):
        d = self.delete(nodeToAsk, keyword, key, signature)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callInv(self, nodeToAsk, serlialized_inv_list):
        d = self.inv(nodeToAsk, *serlialized_inv_list)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callValues(self, nodeToAsk, serlialized_values_list):
        d = self.values(nodeToAsk, *serlialized_values_list)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def transferKeyValues(self, node):
        """
        Given a new node, send it all the keys/values it should be storing.

        @param node: A new node that just joined (or that we just found out
        about).

        Process:
        For each key in storage, get k closest nodes.  If newnode is closer
        than the furthest in that list, and the node for this server
        is closer than the closest in that list, then store the key/value
        on the new node (per section 2.5 of the paper)
        """
        def send_values(inv_list):
            values = []
            if inv_list[0]:
                for requested_inv in inv_list[1]:
                    try:
                        i = objects.Inv()
                        i.ParseFromString(requested_inv)
                        value = self.storage.getSpecific(i.keyword, i.valueKey)
                        if value is not None:
                            v = objects.Value()
                            v.keyword = i.keyword
                            v.valueKey = i.valueKey
                            v.serializedData = value
                            v.ttl = int(round(self.storage.get_ttl(i.keyword, i.valueKey)))
                            values.append(v.SerializeToString())
                    except Exception:
                        pass
                if len(values) > 0:
                    self.callValues(node, values)

        inv = []
        for keyword in self.storage.iterkeys():
            keyword = keyword[0].decode("hex")
            keynode = Node(keyword)
            neighbors = self.router.findNeighbors(keynode, exclude=node)
            if len(neighbors) > 0:
                newNodeClose = node.distanceTo(keynode) < neighbors[-1].distanceTo(keynode)
                thisNodeClosest = self.sourceNode.distanceTo(keynode) < neighbors[0].distanceTo(keynode)
            if len(neighbors) == 0 \
                    or (newNodeClose and thisNodeClosest) \
                    or (thisNodeClosest and len(neighbors) < self.ksize):
                # pylint: disable=W0612
                for k, v in self.storage.iteritems(keyword):
                    i = objects.Inv()
                    i.keyword = keyword
                    i.valueKey = k
                    inv.append(i.SerializeToString())
        if len(inv) > 100:
            random.shuffle(inv)
        if len(inv) > 0:
            self.callInv(node, inv[:100]).addCallback(send_values)

    def handleCallResponse(self, result, node):
        """
        If we get a response, add the node to the routing table.  If
        we get no response, make sure it's removed from the routing table.
        """
        if result[0]:
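            # result is the (responded, value) pair produced by the RPC layer, so
            # reaching this branch means the peer answered. recent_transfers holds
            # the last few peers we already pushed values to, keeping repeated
            # responses from the same node from retriggering transferKeyValues.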
            if self.isNewConnection(node) and node.id not in self.recent_transfers:
                if len(self.recent_transfers) == 10:
                    self.recent_transfers.pop()
                self.recent_transfers.add(node.id)
                self.log.debug("call response from new node, transferring key/values")
                reactor.callLater(1, self.transferKeyValues, node)
            self.router.addContact(node)
        else:
            self.log.debug("no response from %s, removing from router" % node)
            self.router.removeContact(node)
        return result

    def addToRouter(self, node):
        """
        Called by rpc_ functions when a node sends them a request.
        We add the node to our router and transfer our stored values
        if they are new and within our neighborhood.
        """
        if self.isNewConnection(node) and node.id not in self.recent_transfers:
            if len(self.recent_transfers) == 10:
                self.recent_transfers.pop()
            self.recent_transfers.add(node.id)
            self.log.debug("found a new node, transferring key/values")
            reactor.callLater(1, self.transferKeyValues, node)
        self.router.addContact(node)

    def isNewConnection(self, node):
        if (node.ip, node.port) in self.multiplexer:
            return self.multiplexer[(node.ip, node.port)].handler.check_new_connection()
        else:
            return False

    def __iter__(self):
        return iter(self.handled_commands)
Example #19
File: utils.py  Project: Renelvon/Network
class FakeProtocol(object):
    def __init__(self, sourceID, ksize=20):
        self.router = RoutingTable(self, ksize, Node(sourceID))
        self.storage = {}
        self.sourceID = sourceID
        # rpc_store, rpc_find_node and handleCallResponse below log via self.log;
        # assuming the Logger helper used by the protocols above is available,
        # wire one up here so those methods do not fail with an AttributeError.
        self.log = Logger(system=self)

    def getRefreshIDs(self):
        """
        Get ids to search for to keep old buckets up to date.
        """
        ids = []
        for bucket in self.router.getLonelyBuckets():
            ids.append(random.randint(*bucket.range))
        return ids

    def rpc_ping(self, sender, nodeid):
        source = Node(nodeid, sender[0], sender[1])
        self.router.addContact(source)
        return self.sourceID

    def rpc_store(self, sender, nodeid, key, value):
        source = Node(nodeid, sender[0], sender[1])
        self.router.addContact(source)
        self.log.debug("got a store request from %s, storing value" % str(sender))
        self.storage[key] = value

    def rpc_find_node(self, sender, nodeid, key):
        self.log.info("finding neighbors of %i in local table" % long(nodeid.encode('hex'), 16))
        source = Node(nodeid, sender[0], sender[1])
        self.router.addContact(source)
        node = Node(key)
        return map(tuple, self.router.findNeighbors(node, exclude=source))

    def rpc_find_value(self, sender, nodeid, key):
        source = Node(nodeid, sender[0], sender[1])
        self.router.addContact(source)
        value = self.storage.get(key, None)
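        # Standard Kademlia FIND_VALUE semantics: if we do not hold the value,
        # answer exactly as a FIND_NODE would so the caller can keep searching.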
        if value is None:
            return self.rpc_find_node(sender, nodeid, key)
        return {'value': value}

    def callFindNode(self, nodeToAsk, nodeToFind):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.find_node(address, self.sourceID, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callFindValue(self, nodeToAsk, nodeToFind):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.find_value(address, self.sourceID, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callPing(self, nodeToAsk):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.ping(address, self.sourceID)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callStore(self, nodeToAsk, key, value):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.store(address, self.sourceID, key, value)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def handleCallResponse(self, result, node):
        """
        If we get a response, add the node to the routing table.  If
        we get no response, make sure it's removed from the routing table.
        """
        if result[0]:
            self.log.info("got response from %s, adding to router" % node)
            self.router.addContact(node)
        else:
            self.log.debug("no response from %s, removing from router" % node)
            self.router.removeContact(node)
        return result
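
A minimal sketch of how this FakeProtocol test double might be driven directly (the digest helper, Node class, and the peer address here are assumptions borrowed from the surrounding examples, not part of this snippet):

proto = FakeProtocol(digest("source"))
proto.storage["key"] = "value"  # seed the fake storage directly
sender = ("127.0.0.1", 8468)    # hypothetical (ip, port) of a remote peer
assert proto.rpc_ping(sender, digest("peer")) == proto.sourceID
assert proto.rpc_find_value(sender, digest("peer"), "key") == {'value': 'value'}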
Example #20
class KademliaProtocol(RPCProtocol):
    implements(MessageProcessor)

    def __init__(self, sourceNode, storage, ksize, database):
        self.ksize = ksize
        self.router = RoutingTable(self, ksize, sourceNode)
        self.storage = storage
        self.sourceNode = sourceNode
        self.multiplexer = None
        self.db = database
        self.log = Logger(system=self)
        self.handled_commands = [PING, STUN, STORE, DELETE, FIND_NODE, FIND_VALUE, HOLE_PUNCH]
        RPCProtocol.__init__(self, sourceNode.getProto(), self.router)

    def connect_multiplexer(self, multiplexer):
        self.multiplexer = multiplexer

    def getRefreshIDs(self):
        """
        Get ids to search for to keep old buckets up to date.
        """
        ids = []
        for bucket in self.router.getLonelyBuckets():
            ids.append(random.randint(*bucket.range))
        return ids

    def rpc_stun(self, sender):
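        # Echo back the IP and port we observed for the sender so a NATed peer
        # can learn its external address (a minimal STUN-style reply).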
        self.addToRouter(sender)
        return [sender.ip, str(sender.port)]

    def rpc_ping(self, sender):
        self.addToRouter(sender)
        return [self.sourceNode.getProto().SerializeToString()]

    def rpc_store(self, sender, keyword, key, value):
        self.addToRouter(sender)
        self.log.debug("got a store request from %s, storing value" % str(sender))
        if len(keyword) == 20 and len(key) <= 33 and len(value) <= 1800:
            self.storage[keyword] = (key, value)
            return ["True"]
        else:
            return ["False"]

    def rpc_delete(self, sender, keyword, key, signature):
        self.addToRouter(sender)
        value = self.storage.getSpecific(keyword, key)
        if value is not None:
            # Try to delete a message from the dht
            if keyword == digest(sender.id):
                try:
                    verify_key = nacl.signing.VerifyKey(sender.signed_pubkey[64:])
                    verify_key.verify(key, signature)
                    self.storage.delete(keyword, key)
                    return ["True"]
                except Exception:
                    return ["False"]
            # Or try to delete a pointer
            else:
                try:
                    node = objects.Node()
                    node.ParseFromString(value)
                    pubkey = node.signedPublicKey[64:]
                    try:
                        verify_key = nacl.signing.VerifyKey(pubkey)
                        verify_key.verify(signature + key)
                        self.storage.delete(keyword, key)
                        return ["True"]
                    except Exception:
                        return ["False"]
                except Exception:
                    pass
        return ["False"]

    def rpc_find_node(self, sender, key):
        self.log.info("finding neighbors of %s in local table" % key.encode('hex'))
        self.addToRouter(sender)
        node = Node(key)
        nodeList = self.router.findNeighbors(node, exclude=sender)
        ret = []
        for n in nodeList:
            ret.append(n.getProto().SerializeToString())
        return ret

    def rpc_find_value(self, sender, key):
        self.addToRouter(sender)
        ret = ["value"]
        value = self.storage.get(key, None)
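        # Reply with the "value" marker followed by the stored entries when we hold
        # the key; otherwise fall back to FIND_NODE behaviour, as in standard Kademlia.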
        if value is None:
            return self.rpc_find_node(sender, key)
        ret.extend(value)
        return ret

    def callFindNode(self, nodeToAsk, nodeToFind):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.find_node(address, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callFindValue(self, nodeToAsk, nodeToFind):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.find_value(address, nodeToFind.id)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callPing(self, nodeToAsk):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.ping(address)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callStore(self, nodeToAsk, keyword, key, value):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.store(address, keyword, key, value)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def callDelete(self, nodeToAsk, keyword, key, signature):
        address = (nodeToAsk.ip, nodeToAsk.port)
        d = self.delete(address, keyword, key, signature)
        return d.addCallback(self.handleCallResponse, nodeToAsk)

    def transferKeyValues(self, node):
        """
        Given a new node, send it all the keys/values it should be storing.

        @param node: A new node that just joined (or that we just found out
        about).

        Process:
        For each key in storage, get the k closest nodes.  If the new node is
        closer than the furthest node in that list, and this node is closer
        than the closest node in that list, then store the key/value on the
        new node (per section 2.5 of the paper).
        """
        ds = []
        for keyword in self.storage.iterkeys():
            keynode = Node(keyword)
            neighbors = self.router.findNeighbors(keynode, exclude=node)
            if len(neighbors) > 0:
                newNodeClose = node.distanceTo(keynode) < neighbors[-1].distanceTo(keynode)
                thisNodeClosest = self.sourceNode.distanceTo(keynode) < neighbors[0].distanceTo(keynode)
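            # Same section-2.5 replication test as in the transferKeyValues above,
            # except that matching entries are pushed directly with callStore
            # rather than through an inv/values exchange.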
            if len(neighbors) == 0 \
                    or (newNodeClose and thisNodeClosest) \
                    or (thisNodeClosest and len(neighbors) < self.ksize):
                for k, v in self.storage.iteritems(keyword):
                    ds.append(self.callStore(node, keyword, k, v))
        return defer.gatherResults(ds)

    def handleCallResponse(self, result, node):
        """
        If we get a response, add the node to the routing table.  If
        we get no response, make sure it's removed from the routing table.
        """
        if result[0]:
            if self.router.isNewNode(node):
                self.transferKeyValues(node)
            self.log.info("got response from %s, adding to router" % node)
            self.router.addContact(node)
        else:
            self.log.debug("no response from %s, removing from router" % node)
            self.router.removeContact(node)
        return result

    def addToRouter(self, node):
        """
        Called by rpc_ functions when a node sends them a request.
        We add the node to our router and transfer our stored values
        if they are new and within our neighborhood.
        """
        if self.router.isNewNode(node):
            self.log.debug("Found a new node, transferring key/values")
            self.transferKeyValues(node)
        self.router.addContact(node)

    def __iter__(self):
        return iter(self.handled_commands)
Example #21
File: utils.py  Project: Renelvon/Network
 def __init__(self, sourceID, ksize=20):
     self.router = RoutingTable(self, ksize, Node(sourceID))
     self.storage = {}
     self.sourceID = sourceID
Example #22
 def setUp(self):
     self.node = Node(digest("test"), "127.0.0.1", 1234)
     self.router = RoutingTable(self, 20, self.node.id)