forked from albertz/playground
/
ipc_benchmark.py
executable file
·356 lines (304 loc) · 10.7 KB
/
ipc_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
#!/usr/bin/env python
import sys
import os
import subprocess
import numpy
import ctypes
import ctypes.util
import atexit
import gc
import better_exchook
# Install better_exchook as the exception hook (presumably for more detailed tracebacks).
better_exchook.install()
# Load the C library. use_errno=True makes ctypes preserve errno after each
# foreign call so check_ccall_error() can read it via ctypes.get_errno().
libc_so = ctypes.util.find_library('c')
libc = ctypes.CDLL(libc_so, use_errno=True)
# C type used for shmget's key_t argument; assumed to be int -- TODO confirm per platform.
shm_key_t = ctypes.c_int
# Constants hard-coded instead of read from <sys/ipc.h>.
# NOTE(review): values assumed to match the target platforms' headers -- confirm.
IPC_PRIVATE = 0  # shmget key used by SharedMem when it creates a new segment
IPC_RMID = 0  # shmctl command used by SharedMem.remove() on segments it created
# int shmget(key_t key, size_t size, int shmflg);
shmget = libc.shmget
shmget.restype = ctypes.c_int
shmget.argtypes = (shm_key_t, ctypes.c_size_t, ctypes.c_int)
# void* shmat(int shmid, const void *shmaddr, int shmflg);
# With restype c_void_p the result comes back as a Python int (or None for NULL).
shmat = libc.shmat
shmat.restype = ctypes.c_void_p
shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int)
# int shmdt(const void *shmaddr);
shmdt = libc.shmdt
shmdt.restype = ctypes.c_int
shmdt.argtypes = (ctypes.c_void_p,)
# int shmctl(int shmid, int cmd, struct shmid_ds *buf);
# buf is declared as an opaque pointer; callers here only pass 0 (NULL).
shmctl = libc.shmctl
shmctl.restype = ctypes.c_int
shmctl.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
# void* memcpy( void *dest, const void *src, size_t count );
# Used by SharedNumpyArray.create_copy to copy array bytes into shared memory.
memcpy = libc.memcpy
memcpy.restype = ctypes.c_void_p
memcpy.argtypes = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t)
def _darwin_get_sysctl(key):
    """Return the integer value of the sysctl ``key`` (e.g. ``kern.sysv.shmmax``).

    Runs ``sysctl <key>`` and expects exactly two whitespace-separated fields,
    ``"<key>:"`` and the value.

    :param str key: sysctl name.
    :return: the current value.
    :rtype: int
    :raises subprocess.CalledProcessError: if ``sysctl`` exits with an error.
    """
    # universal_newlines=True returns text (str) instead of bytes on Python 3;
    # without it the comparison against the "%s:" str below can never succeed.
    res = subprocess.check_output(["sysctl", key], universal_newlines=True).split()
    assert len(res) == 2, "unexpected sysctl output: %r" % (res,)
    assert res[0] == "%s:" % key, "unexpected sysctl output: %r" % (res,)
    return int(res[1])
def _darwin_set_sysctl(key, value):
    """Set the sysctl ``key`` to ``value`` by running ``sudo sysctl -w key=value``.

    The command is printed before it runs; a non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    assignment = "%s=%s" % (key, value)
    command = ["sudo", "sysctl", "-w", assignment]
    print("Calling %r" % command)
    subprocess.check_call(command)
def _darwin_check_sysctl(key, minvalue):
    """Ensure the sysctl ``key`` is at least ``minvalue``.

    If the current value (from _darwin_get_sysctl) is already large enough
    nothing happens; otherwise the shortfall is printed and the key is set to
    exactly ``minvalue`` via _darwin_set_sysctl.
    """
    current = _darwin_get_sysctl(key)
    if current >= minvalue:
        return
    print("Value %s = %s < %s" % (key, current, minvalue))
    _darwin_set_sysctl(key, minvalue)
# Largest SysV shared-memory segment size (bytes) this benchmark wants to allow.
ShmmaxWanted = 1024 ** 3
# kern.sysv.shmall is counted in 4 KiB pages, not bytes. Integer division keeps it
# an int (true division would hand sysctl a float such as "265252.34...").
ShmallWanted = ShmmaxWanted // 4096
def check_shmmax():
    """Make sure the kernel permits shared-memory segments of ``ShmmaxWanted`` bytes.

    On macOS (``sys.platform == "darwin"``) this checks ``kern.sysv.shmmax``
    and ``kern.sysv.shmall`` and raises them with ``sudo sysctl -w`` when they
    are below the wanted values. On other platforms it only prints that the
    check is not implemented.
    """
    if sys.platform == "darwin":
        # https://support.apple.com/en-us/HT4022
        _darwin_check_sysctl("kern.sysv.shmmax", ShmmaxWanted)
        _darwin_check_sysctl("kern.sysv.shmall", ShmallWanted)
        #_darwin_check_sysctl("kern.sysv.shmmni", 256) # Operation not permitted?
        #_darwin_check_sysctl("kern.sysv.shmseg", 256)
        # Maybe call cleanup_shared_mem.py?
    else:
        print("check_shmmax not implemented for platform %r" % sys.platform)
def check_ccall_error(check, f):
    """Raise an Exception built from the saved C ``errno`` unless ``check`` is truthy.

    :param check: truthy when the preceding C call succeeded.
    :param str f: name of the C function, used in the message.
    :raises Exception: ``"<f> failed with error <errno> (<strerror>)"``.
    """
    if check:
        return
    err = ctypes.get_errno()
    raise Exception("%s failed with error %i (%s)" % (f, err, os.strerror(err)))
class SharedMem:
    """A System V shared-memory segment attached into this process.

    With ``shmid=None`` a new private segment of ``size`` bytes is created;
    this instance is then its creator and :meth:`remove` (also registered with
    atexit) removes it again. Passing an existing ``shmid`` attaches to that
    segment without taking ownership. Pickling transfers only ``size`` and
    ``shmid``, so unpickling attaches the receiving process to the same segment.
    """
    def __init__(self, size, shmid=None):
        """
        :param int size: segment size in bytes.
        :param int|None shmid: id of an existing segment to attach to, or None
            to create a new segment.
        :raises Exception: via check_ccall_error if shmget/shmat fail.
        """
        self.size = size
        self.shmid = None
        self.ptr = None
        if shmid is None:
            self.is_creator = True
            new_shmid = shmget(IPC_PRIVATE, self.size, 0o600)
            # shmget signals failure with -1; 0 is a valid segment id.
            # Validate before storing so remove() never acts on -1.
            check_ccall_error(new_shmid >= 0, "shmget")
            self.shmid = new_shmid
            print("New shmid: %i" % self.shmid)
            atexit.register(self.remove)
        else:
            self.is_creator = False
            assert shmid >= 0
            self.shmid = shmid
        ptr = shmat(self.shmid, 0, 0)
        # shmat signals failure with (void*)-1; None would be a NULL pointer.
        check_ccall_error(ptr is not None and ptr != ctypes.c_void_p(-1).value, "shmat")
        self.ptr = ptr
    def remove(self):
        """Detach the segment and, if this instance created it, remove it.

        Safe to call repeatedly. Return codes of shmdt/shmctl are ignored
        (best-effort cleanup).
        """
        if self.ptr:
            shmdt(self.ptr)
            self.ptr = None
        # Compare against None: shmid 0 is a valid segment and must be removed too.
        if self.shmid is not None and self.is_creator:
            shmctl(self.shmid, IPC_RMID, 0)
            print("Removed shmid %i" % self.shmid)
            self.shmid = None
    def __del__(self):
        self.remove()
    def __getstate__(self):
        # Only the handle is pickled; the receiver re-attaches in __setstate__.
        return {"size": self.size, "shmid": self.shmid}
    def __setstate__(self, state):
        self.__init__(**state)
    def __repr__(self):
        return "<SharedMem shmid=%r size=%r is_creator=%r>" % (self.shmid, self.size, self.is_creator)
def next_power_of_two(n):
    """Return the smallest power of two that is >= ``n`` (1 for ``n <= 1``).

    ``n`` may be a Python int or a NumPy integer scalar. It is converted with
    ``int()`` first because NumPy scalars (e.g. from ``numpy.prod``) may lack
    ``int.bit_length()`` on Python 3.
    """
    n = int(n)
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()
class SharedNumpyArray:
    """A NumPy array whose data lives in a SharedMem segment.

    The server side (the process that creates the data) allocates the
    segments and keeps its instances in ``ServerInstances`` so a segment can
    be reused once the client has cleared its in-use flag. The first 8 bytes
    of each segment hold that flag as a uint64; the array data starts
    ``ExtraSpaceBytes`` into the segment. Pickling transfers only shape,
    strides, typestr and the SharedMem handle, so the receiving (client)
    process maps the same memory instead of copying the data.
    """
    # cls members
    ServerInstances = set()  # server-side instances, candidates for reuse
    MaxServerInstances = 2  # upper bound on len(ServerInstances)
    ExtraSpaceBytes = 4048  # header bytes before the array data (holds the in-use flag)
    # local members
    is_server = False
    mem = None
    shape, strides, typestr = None, None, None
    array = None
    @classmethod
    def needed_mem_size(cls, shape, typestr):
        """Return the bytes needed for the header plus an array of ``shape``/``typestr``.

        :param tuple[int] shape: array shape (``()`` for a 0-d array).
        :param str typestr: array-interface type string, e.g. ``"<f8"``.
        :rtype: int
        """
        itemsize = numpy.dtype(typestr).itemsize
        # numpy.prod(()) is the float 1.0 and numpy.prod of ints is a NumPy
        # scalar; convert so shmget's size_t argument gets a plain int.
        mem_size = cls.ExtraSpaceBytes + itemsize * int(numpy.prod(shape))
        return int(mem_size)
    @classmethod
    def create_copy(cls, array):
        """Return a server-side instance whose shared array holds a copy of ``array``.

        :param numpy.ndarray array: data to copy into shared memory.
        :rtype: SharedNumpyArray
        """
        assert isinstance(array, numpy.ndarray)
        # memcpy below copies array.nbytes bytes starting at the data pointer,
        # which is the whole array only for C-contiguous memory (not e.g. for
        # slices or reversed views), so make a contiguous copy first.
        if not array.flags["C_CONTIGUOUS"]:
            array = array.copy(order="C")
        array_intf = array.__array_interface__
        shape = array_intf["shape"]
        strides = array_intf["strides"]
        typestr = array_intf["typestr"]
        inst = cls.create_new(shape=shape, strides=strides, typestr=typestr)
        assert array.nbytes == inst.array.nbytes
        memcpy(inst.array.ctypes.data, array.ctypes.data, array.nbytes)
        return inst
    @classmethod
    def create_new(cls, shape, strides, typestr):
        """Return a server-side instance with an array of the given layout.

        Reuses a registered server instance whose in-use flag is 0, growing its
        segment if it is too small; otherwise constructs a new instance. The
        array contents are not initialized here.
        """
        needed_mem_size = cls.needed_mem_size(shape=shape, typestr=typestr)
        for inst in cls.ServerInstances:
            assert isinstance(inst, SharedNumpyArray)
            if inst.is_in_use():
                continue
            if inst.mem.size < needed_mem_size:
                inst._init_mem(shape=shape, typestr=typestr)
            inst._set_is_used(1)
            inst._create_numpy(shape=shape, strides=strides, typestr=typestr)
            return inst
        return cls(shape=shape, strides=strides, typestr=typestr)
    @classmethod
    def create_from_shared(cls, shape, strides, typestr, mem):
        """Return a client-side instance viewing the existing segment ``mem``."""
        return cls(shape=shape, strides=strides, typestr=typestr, mem=mem)
    def __init__(self, shape, strides, typestr, mem=None):
        """
        :param tuple[int] shape:
        :param tuple[int]|None strides: None for C-contiguous data.
        :param str typestr: array-interface type string.
        :param SharedMem|None mem: existing segment (client side), or None to
            allocate a new segment and register this instance as a server instance.
        """
        if mem is None:
            assert len(self.ServerInstances) < self.MaxServerInstances
            self.is_server = True
            self._init_mem(shape=shape, typestr=typestr)
            self._set_is_used(1)
            self.ServerInstances.add(self)
        else:
            self.is_server = False
            mem_size = self.needed_mem_size(shape=shape, typestr=typestr)
            assert isinstance(mem, SharedMem)
            assert mem.size >= mem_size
            # 0 is a valid shmid; None means the segment was already removed.
            assert mem.shmid is not None and mem.shmid >= 0
            assert mem.ptr
            self.mem = mem
        self._create_numpy(shape=shape, strides=strides, typestr=typestr)
    def _init_mem(self, shape, typestr):
        """(Re)allocate this server instance's segment, rounded up to a power of two."""
        assert self.is_server
        if self.mem:
            self.mem.remove()
            self.mem = None
        mem_size = next_power_of_two(self.needed_mem_size(shape=shape, typestr=typestr))
        self.mem = SharedMem(size=mem_size)
    def _create_numpy(self, shape, strides, typestr):
        """Set ``self.array`` to a view (no copy) of the segment's data area."""
        assert self.mem.ptr > 0
        self.shape = shape
        self.strides = strides
        self.typestr = typestr
        # http://docs.scipy.org/doc/numpy/reference/arrays.interface.html
        array_intf = {
            "data": (self.mem.ptr + self.ExtraSpaceBytes, False),
            "shape": shape,
            "strides": strides,
            'typestr': typestr,
            "version": 3
        }
        # The array's base is A, and A._base references self, so this wrapper
        # (and thus its SharedMem attachment) lives as long as the array does.
        class A:
            _base = self
            __array_interface__ = array_intf
        a = numpy.array(A, copy=False)
        assert not a.flags.owndata
        assert a.base is A
        self.array = a
    def _get_in_use_flag_ref(self):
        """Return a ctypes uint64 aliasing the first 8 bytes of the segment."""
        assert self.mem.ptr > 0
        return ctypes.cast(ctypes.c_void_p(self.mem.ptr), ctypes.POINTER(ctypes.c_uint64)).contents
    def _set_is_used(self, n):
        self._get_in_use_flag_ref().value = n
    def is_in_use(self):
        return self._get_in_use_flag_ref().value > 0
    def set_unused(self):
        """Drop the array view, clear the in-use flag and release the SharedMem handle."""
        self.array = None
        if self.mem:
            self._set_is_used(0)
            self.mem.remove()
            self.mem = None
    def __getstate__(self):
        return {
            "shape": self.shape, "strides": self.strides, "typestr": self.typestr,
            "mem": self.mem
        }
    def __setstate__(self, state):
        self.__init__(**state)
    def __del__(self):
        # On the server side, we will get deleted at program end
        # because we are referenced in the class-level SharedNumpyArray.ServerInstances.
        # On the client side, we will get deleted once we are not used anymore.
        # Note that self.array holds a reference to self.
        self.set_unused()
    def __repr__(self):
        return "<%s is_server=%r state=%r>" % (self.__class__.__name__, self.is_server, self.__getstate__())
def pickle_std(s, v):
    """Serialize ``v`` to the binary stream ``s`` with the stdlib pickle module.

    Uses the highest available protocol and flushes ``s`` afterwards.
    """
    # Imported locally: at module level the name ``pickle`` is rebound to one
    # of the pickle_* functions of this file.
    import pickle as pickle_module
    pickle_module.Pickler(file=s, protocol=-1).dump(v)
    s.flush()
def pickle_ext(s, v):
    """Serialize ``v`` to the stream ``s`` using the third-party extpickle Pickler, then flush ``s``."""
    import extpickle
    extpickle.Pickler(file=s).dump(v)
    s.flush()
def pickle_shm(s, v):
    """Pickle the message tuple ``v`` to ``s``, sending an ndarray payload via shared memory.

    For a 2-tuple ``(tag, array)`` the array is copied into a server-side
    SharedNumpyArray and that wrapper (whose pickled state is the layout plus
    the SharedMem handle) is pickled instead of the array data. Tuples of any
    other length are pickled unchanged. ``s`` is flushed afterwards.
    """
    import pickle as pickle_module
    assert isinstance(v, tuple)
    if len(v) == 2:
        tag, payload = v
        assert isinstance(payload, numpy.ndarray)
        wrapper = SharedNumpyArray.create_copy(payload)
        # Drop the local view; the data stays in the shared segment.
        wrapper.array = None
        v = (tag, wrapper)
    pickle_module.Pickler(file=s, protocol=-1).dump(v)
    s.flush()
# Serializer used by demo()/demo_client(); the __main__ block rebinds it to
# pickle_shm, pickle_std or pickle_ext according to the command-line mode.
# NOTE: this shadows the stdlib ``pickle`` module at module level, which is why
# the pickle_* and unpickle functions import it locally.
pickle = pickle_std
def unpickle(s):
    """Read one message tuple from the binary stream ``s``.

    A ``(tag, SharedNumpyArray)`` message (as written by pickle_shm) is
    returned as ``(tag, ndarray)``, where the ndarray views the shared memory.
    The wrapper's own ``array`` attribute is cleared; the array still
    references the wrapper through its ``base``.
    Any other tuple is returned as-is.
    """
    import pickle as pickle_module
    msg = pickle_module.Unpickler(file=s).load()
    assert isinstance(msg, tuple)
    if len(msg) != 2 or not isinstance(msg[1], SharedNumpyArray):
        return msg
    wrapper = msg[1]
    assert wrapper.array is not None
    arr = wrapper.array
    wrapper.array = None
    return (msg[0], arr)
LoopCount = 10  # number of ping/pong round trips
MatrixSize = 5000  # each message carries a MatrixSize x MatrixSize float matrix
def demo():
    """Run the benchmark: round-trip LoopCount random matrices through a client process.

    Spawns this script again with the same mode argument plus ``--client``.
    It sends ``("ping", matrix)`` on the child's stdin and expects
    ``("pong", matrix)`` back on its stdout, checking shape and values each
    round. Finally it sends ``("exit",)``, waits for the acknowledgement and
    then for the child process to finish.
    """
    if pickle is pickle_shm:
        check_shmmax()
    # Launch the client with this very interpreter rather than through the
    # "#!/usr/bin/env python" shebang. The file then need not be executable,
    # and the child can read this interpreter's highest pickle protocol
    # (protocol=-1), which a different "python" on PATH might not support.
    p = subprocess.Popen(
        [sys.executable, __file__] + sys.argv[1:] + ["--client"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    for _ in range(LoopCount):
        m = numpy.random.randn(MatrixSize, MatrixSize)
        pickle(p.stdin, ("ping", m))
        out, m2 = unpickle(p.stdout)
        assert out == "pong"
        assert isinstance(m2, numpy.ndarray)
        assert m.shape == m2.shape
        assert numpy.isclose(m, m2).all()
        # Drop the received array and collect it now, so any SharedNumpyArray
        # wrapper behind it is released before the next round.
        del m2
        gc.collect()
    print("Copying done, exiting.")
    pickle(p.stdin, ("exit",))
    out, = unpickle(p.stdout)
    assert out == "exit"
    p.stdin.close()
    p.wait()
    p.stdout.close()
    print("Done. Return code %i" % p.returncode)
def demo_client():
    """Client side of the benchmark: echo ``("ping", array)`` back as ``("pong", array)``.

    Reads messages from stdin and writes replies to stdout until it receives
    ``("exit",)``, which it acknowledges before returning. ``sys.stdout`` is
    redirected to stderr so regular prints cannot corrupt the pickle stream.
    """
    # pickle needs binary streams. On Python 3, sys.stdin/sys.stdout are text
    # wrappers whose binary stream is ``.buffer``; Python 2 file objects have
    # no ``.buffer`` and are used directly.
    in_stream = getattr(sys.stdin, "buffer", sys.stdin)
    out_stream = getattr(sys.stdout, "buffer", sys.stdout)
    sys.stdout = sys.stderr
    print("Hello from client!")
    while True:
        cmd = unpickle(in_stream)
        assert isinstance(cmd, tuple)
        if cmd[0] == "exit":
            pickle(out_stream, ("exit",))
            break
        elif cmd[0] == "ping":
            assert isinstance(cmd[1], numpy.ndarray)
            pickle(out_stream, ("pong", cmd[1]))
        else:
            assert False, "unknown: %r" % cmd
        # Release the received array (and any shared-memory wrapper) promptly.
        del cmd
        gc.collect()
    print("Exit from client!")
if __name__ == "__main__":
    # Usage: <script> (--shared_mem | --pickle | --extpickle) [--client]
    # The mode selects the serializer; demo() re-launches this script with the
    # same mode plus --client for the child process.
    _serializers = {
        "--shared_mem": pickle_shm,
        "--pickle": pickle_std,
        "--extpickle": pickle_ext,
    }
    # SystemExit rather than assert so invalid arguments are still rejected
    # under "python -O" (and a missing argument no longer raises IndexError).
    if len(sys.argv) < 2 or sys.argv[1] not in _serializers:
        raise SystemExit("unknown args: %r" % sys.argv[1:])
    pickle = _serializers[sys.argv[1]]
    if sys.argv[2:] == ["--client"]:
        demo_client()
    elif sys.argv[2:] == []:
        demo()
    else:
        raise SystemExit("unknown args: %r" % sys.argv[1:])