/
network.py
218 lines (185 loc) · 7.81 KB
/
network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import random
import pickle
import asyncio
import logging
from node import Node
from utils import digest
from crawling import NodeSpiderCrawl
from protocol import KademliaProtocol
from crawling import ValueSpiderCrawl
from storage import ForgetfulStorage, AwesomeStorage
# Module-level logger, named after this module; used by Server below.
log = logging.getLogger(__name__)
class Server:
    """
    High-level view of a node instance.

    Owns the local ``Node``, the storage backend and the datagram endpoint,
    and exposes the get / set / delete operations performed against the
    network.  Override ``protocol_class`` to change the protocol object
    that ``_create_protocol`` builds.
    """

    protocol_class = KademliaProtocol

    def __init__(self, ksize=20, alpha=3, node_id=None, storage=None):
        """
        Create a server instance.  Nothing is bound until ``listen`` runs.

        Args:
            ksize: Forwarded to the protocol and to every spider crawl
                (presumably the Kademlia bucket size k).
            alpha: Forwarded to every spider crawl; also the neighbor count
                requested while refreshing the routing table.
            node_id: Id for this node.  When falsy, the digest of 255 random
                bits is used instead.
            storage: Storage backend.  When falsy, a ``ForgetfulStorage``
                is created.
        """
        self.ksize = ksize
        self.alpha = alpha
        self.storage = (storage or ForgetfulStorage())
        self.node = Node(node_id or digest(random.getrandbits(255)))
        self.transport = None
        self.protocol = None
        self.refresh_loop = None
        self.save_state_loop = None
        # Contacts restored by load_state(); consumed by listen() once the
        # protocol exists.
        self._pending_bootstrap = None

    def stop(self):
        """Close the transport (if any) and cancel the scheduled callbacks."""
        if self.transport is not None:
            self.transport.close()

        if self.refresh_loop:
            self.refresh_loop.cancel()

        if self.save_state_loop:
            self.save_state_loop.cancel()

    def _create_protocol(self):
        """Build the protocol object handed to the datagram endpoint."""
        return self.protocol_class(self.node, self.storage, self.ksize)

    async def listen(self, port, interface='0.0.0.0'):
        """
        Bind a datagram endpoint on ``interface:port`` and start the
        periodic routing-table refresh.

        If this server was produced by ``load_state`` with saved contacts,
        a bootstrap against those contacts is scheduled here, because
        bootstrapping needs the protocol that only exists after binding.
        """
        self.node = Node(self.node.id, interface, port)
        loop = asyncio.get_event_loop()
        listen = loop.create_datagram_endpoint(self._create_protocol,
                                               local_addr=(interface, port))
        log.info("Node %i listening on %s:%i",
                 self.node.long_id, interface, port)
        self.transport, self.protocol = await listen
        # finally, schedule refreshing table
        self.refresh_table()

        # getattr keeps this safe for subclasses whose __init__ does not
        # chain up to Server.__init__.
        pending = getattr(self, '_pending_bootstrap', None)
        if pending:
            self._pending_bootstrap = None
            asyncio.ensure_future(self.bootstrap(pending))

    def refresh_table(self):
        """
        Start a routing-table refresh in the background and re-arm this
        method to run again in 3600 seconds.
        """
        log.debug("Refreshing routing table")
        asyncio.ensure_future(self._refresh_table())
        loop = asyncio.get_event_loop()
        self.refresh_loop = loop.call_later(3600, self.refresh_table)

    async def _refresh_table(self):
        """
        Crawl towards every id returned by ``protocol.get_refresh_ids()``,
        then republish the entries yielded by
        ``storage.iter_older_than(3600)``.
        """
        crawls = []
        for node_id in self.protocol.get_refresh_ids():
            node = Node(node_id)
            nearest = self.protocol.router.find_neighbors(node, self.alpha)
            spider = NodeSpiderCrawl(self.protocol, node, nearest,
                                     self.ksize, self.alpha)
            crawls.append(spider.find())

        # do our crawling
        await asyncio.gather(*crawls)

        # NOTE(review): set_digest takes (dkey, key, name, value, hash), so
        # this two-argument call raises TypeError as soon as the storage
        # yields an entry.  The shape of the stored value is not visible
        # from this file, so the call is left unchanged; fix it together
        # with the storage backend.
        for dkey, value in self.storage.iter_older_than(3600):
            await self.set_digest(dkey, value)

    def _bootstrappable_neighbors(self):
        """
        Synchronous core of ``bootstrappable_neighbors``.

        Returns the last two fields of each neighbor of this node (used as
        bootstrap addresses), or an empty list when ``listen`` has not run
        yet and there is no protocol to ask.
        """
        if self.protocol is None:
            return []
        neighbors = self.protocol.router.find_neighbors(self.node)
        log.debug("Bootstrappable neighbors: %s", neighbors)
        return [tuple(n)[-2:] for n in neighbors]

    async def bootstrappable_neighbors(self):
        """
        Return a list of address tuples suitable for passing to
        ``bootstrap`` later (for example after a restart).

        Kept as a coroutine for existing callers; it does no awaiting
        itself.
        """
        return self._bootstrappable_neighbors()

    async def bootstrap(self, addrs):
        """
        Ping the given addresses and crawl the network from the responders.

        Args:
            addrs: Sized collection of indexable ``(host, port)`` pairs.

        Returns:
            The result of the ``NodeSpiderCrawl`` towards this node's own id.
        """
        log.debug("Attempting to bootstrap node with %i initial contacts",
                  len(addrs))
        cos = list(map(self.bootstrap_node, addrs))
        gathered = await asyncio.gather(*cos)
        nodes = [node for node in gathered if node is not None]
        spider = NodeSpiderCrawl(self.protocol, self.node, nodes,
                                 self.ksize, self.alpha)
        return await spider.find()

    async def bootstrap_node(self, addr):
        """
        Ping a single ``(host, port)`` address.

        Returns:
            A ``Node`` built from the reply when the ping result is marked
            successful (``result[0]`` truthy), otherwise ``None``.
        """
        result = await self.protocol.ping(addr, self.node.id)
        return Node(result[1], addr[0], addr[1]) if result[0] else None

    async def get(self, key, hash=False):
        """
        Look up ``key`` on the network.

        Args:
            key: Key to look up.
            hash: When true, ``key`` is passed through ``digest`` first.
                ``set`` always digests its key, so pass ``hash=True`` to
                find a value stored through ``set``.

        Returns:
            The result of the ``ValueSpiderCrawl``, or ``None`` when no
            neighbors are known.
        """
        log.info("Looking up key %s", key)
        dkey = key
        if hash:
            dkey = digest(key)

        node = Node(dkey)
        nearest = self.protocol.router.find_neighbors(node)
        if not nearest:
            log.warning("There are no known neighbors to get key %s", key)
            return None
        spider = ValueSpiderCrawl(self.protocol, node, nearest,
                                  self.ksize, self.alpha)
        return await spider.find()

    async def delete(self, key, hash=True):
        """
        Ask the nodes closest to ``key`` to delete it.

        This node's own storage is not touched.

        Args:
            key: Key to delete.
            hash: When true (the default), ``key`` is digested first.

        Returns:
            ``True`` if at least one delete call reported success, ``False``
            if none did, ``None`` when no neighbors are known.
        """
        dkey = key
        if hash:
            dkey = digest(key)

        node = Node(dkey)
        nearest = self.protocol.router.find_neighbors(node)
        if not nearest:
            log.warning("There are no known neighbors to delete key %s", key)
            return None

        spider = NodeSpiderCrawl(self.protocol, node, nearest,
                                 self.ksize, self.alpha)
        nodes = await spider.find()

        results = [self.protocol.call_delete(n, dkey) for n in nodes]
        # return true only if at least one delete call succeeded
        return any(await asyncio.gather(*results))

    async def delete_tag(self, key, value):
        """
        Ask the nodes closest to ``digest(key)`` to delete a tag.

        Args:
            key: Key whose digest selects the target nodes; also forwarded
                unhashed to ``call_delete_tag``.
            value: Forwarded to ``call_delete_tag``.

        Returns:
            ``True`` if at least one call reported success, ``False`` if
            none did, ``None`` when no neighbors are known.
        """
        dkey = digest(key)
        node = Node(dkey)

        nearest = self.protocol.router.find_neighbors(node)
        if not nearest:
            log.warning("There are no known neighbors to delete tag for "
                        "key %s", key)
            return None

        spider = NodeSpiderCrawl(self.protocol, node, nearest,
                                 self.ksize, self.alpha)
        nodes = await spider.find()
        log.info("deleting tag for '%s' on %s",
                 dkey.hex(), list(map(str, nodes)))

        results = [self.protocol.call_delete_tag(n, dkey, key, value)
                   for n in nodes]
        # return true only if at least one delete-tag call succeeded
        return any(await asyncio.gather(*results))

    async def set(self, key, name, value, hash=True):
        """
        Store ``value`` on the network under ``digest(key)``.

        Args:
            key: Key to store under; it is always digested here.
            name: Forwarded unchanged to storage and to the store call.
            value: Must satisfy ``check_dht_value_type``.
            hash: Forwarded unchanged to storage and to the store call; it
                does not control the digesting of ``key``.

        Returns:
            The result of ``set_digest``.

        Raises:
            TypeError: If ``value`` is not of an accepted type.
        """
        if not check_dht_value_type(value):
            raise TypeError(
                "Value must be of type int, float, bool, str, bytes, or set"
            )
        log.info("setting '%s' = '%s' on network", key, value)
        dkey = digest(key)
        return await self.set_digest(dkey, key, name, value, hash)

    async def set_digest(self, dkey, key, name, value, hash=True):
        """
        Store ``value`` on the nodes closest to the already-digested
        ``dkey``, and locally when this node is closer than the furthest
        of them.

        Returns:
            ``True`` if at least one store call reported success, otherwise
            ``False`` (including when no neighbors are known or no node
            answered the crawl).
        """
        node = Node(dkey)

        nearest = self.protocol.router.find_neighbors(node)
        if not nearest:
            log.warning("There are no known neighbors to set key %s",
                        dkey.hex())
            return False

        spider = NodeSpiderCrawl(self.protocol, node, nearest,
                                 self.ksize, self.alpha)
        nodes = await spider.find()
        log.info("setting '%s' on %s", dkey.hex(), list(map(str, nodes)))

        # The crawl can come back empty; max() on an empty sequence would
        # raise ValueError, so report "nothing stored" instead.
        if not nodes:
            log.warning("No nodes responded while setting key %s",
                        dkey.hex())
            return False

        # if this node is close too, then store here as well
        biggest = max(n.distance_to(node) for n in nodes)
        if self.node.distance_to(node) < biggest:
            self.storage.set(dkey, key, name, value, hash)
        results = [self.protocol.call_store(n, dkey, key, name, value, hash)
                   for n in nodes]
        # return true only if at least one store call succeeded
        return any(await asyncio.gather(*results))

    def save_state(self, fname):
        """
        Pickle ksize, alpha, this node's id and the bootstrappable
        neighbors to ``fname``.  Nothing is written when no neighbors are
        known.
        """
        log.info("Saving state to %s", fname)
        data = {
            'ksize': self.ksize,
            'alpha': self.alpha,
            'id': self.node.id,
            # Use the synchronous helper: calling the coroutine here would
            # store an un-awaited (always truthy, unpicklable) coroutine.
            'neighbors': self._bootstrappable_neighbors()
        }
        if not data['neighbors']:
            log.warning("No known neighbors, so not writing to cache.")
            return
        with open(fname, 'wb') as file:
            pickle.dump(data, file)

    @classmethod
    def load_state(cls, fname):
        """
        Build a server from a file written by ``save_state``.

        The saved contacts cannot be bootstrapped here because there is no
        protocol until ``listen`` runs, so they are stashed on the instance
        and ``listen`` schedules the bootstrap.

        Returns:
            A new, not yet listening, instance of ``cls``.
        """
        log.info("Loading state from %s", fname)
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # state files that this application wrote itself.
        with open(fname, 'rb') as file:
            data = pickle.load(file)
        svr = cls(data['ksize'], data['alpha'], data['id'])
        if data['neighbors']:
            svr._pending_bootstrap = list(data['neighbors'])
        return svr

    def save_state_regularly(self, fname, frequency=600):
        """
        Save the state to ``fname`` now, then again every ``frequency``
        seconds.

        Args:
            fname: File name to save to; it is overwritten on each save.
            frequency: Delay in seconds between saves (default 600).
        """
        self.save_state(fname)
        loop = asyncio.get_event_loop()
        self.save_state_loop = loop.call_later(frequency,
                                               self.save_state_regularly,
                                               fname,
                                               frequency)
def check_dht_value_type(value):
    """
    Report whether ``value`` has one of the types accepted as a DHT value.

    The comparison is on the exact type, so instances of subclasses of the
    accepted types are rejected.
    """
    accepted = (int, float, bool, str, bytes, set)
    kind = type(value)
    return kind in accepted