def recv_command(self):
    """Read one command frame from the command socket and decode it.

    A frame consists of a 7-byte header -- ENQ (1 byte), command code
    (2 bytes) and payload length (4 bytes, decoded with ``LENGTH``) --
    followed by the payload itself.

    Returns the result of ``self.serializer.deserialize_cmd``.

    Raises ProtocolError when the peer closes the connection, the header
    does not start with ENQ, the payload cannot be decoded (an error reply
    is sent to the peer first in that case), or a socket error occurs.
    """
    try:
        # fixed-size header: ENQ (1 byte) + commandcode (2) + length (4)
        header = b''
        while len(header) < 7:
            chunk = self.sock.recv(7 - len(header))
            if not chunk:
                raise ProtocolError('recv_command: connection broken')
            header += chunk
        if header[0:1] != ENQ:
            raise ProtocolError('recv_command: invalid command header')

        cmdcode = header[1:3]
        length, = LENGTH.unpack(header[3:])

        # payload: never ask for more than is still missing
        payload = b''
        while len(payload) < length:
            wanted = min(READ_BUFSIZE, length - len(payload))
            chunk = self.sock.recv(wanted)
            if not chunk:
                raise ProtocolError('recv_command: connection broken')
            payload += chunk

        try:
            return self.serializer.deserialize_cmd(
                payload, code2command[cmdcode])
        except Exception:
            # covers unknown command codes as well as undecodable payloads
            self.send_error_reply('invalid command or garbled data')
            raise ProtocolError(
                'recv_command: invalid command or garbled data')
    except socket.error as err:
        raise ProtocolError('recv_command: connection broken (%s)' % err)
def recv_reply(self):
    """Read one reply frame from the command socket and decode it.

    A reply is either a bare ACK, or NAK/STX followed by a 4-byte length
    (decoded with ``LENGTH``) and a payload of that many bytes.

    Returns ``(True, None)`` for a bare ACK, otherwise the result of
    ``self.serializer.deserialize_reply``.  If no serializer has been
    selected yet, it is determined from this first payload via
    ``self.determine_serializer``.

    Raises ProtocolError when the peer closes the connection or the reply
    does not start with ACK, NAK or STX.
    """
    # receive first byte + (possibly) length
    start = b''
    while len(start) < 5:
        data = self.sock.recv(5 - len(start))
        if not data:
            raise ProtocolError('connection broken')
        start += data
        # a bare ACK carries no length field, so stop as soon as it is seen
        if start == ACK:
            return True, None
        if start[0:1] not in (NAK, STX):
            raise ProtocolError('invalid response %r' % start)
    # it has a length...
    length, = LENGTH.unpack(start[1:])
    chunks = []
    remaining = length
    while remaining > 0:
        # never request more than the frame still owes us; reading a full
        # READ_BUFSIZE here could swallow the start of the next frame
        read = self.sock.recv(min(READ_BUFSIZE, remaining))
        if not read:
            raise ProtocolError('connection broken')
        chunks.append(read)
        remaining -= len(read)
    buf = b''.join(chunks)

    success = start[0:1] == STX
    if not self.serializer:
        self.serializer = self.determine_serializer(buf, success)
    # XXX: handle errors
    return self.serializer.deserialize_reply(buf, success)
def recv_event(self):
    """Read one event frame, including attached blobs, from the event socket.

    The 8-byte header is STX (1 byte) + event code (2) + number of blobs
    (1) + payload length (4, decoded with ``LENGTH``).  The payload is
    followed by the announced number of blobs, each read with
    ``self._recv_blob``.

    Returns the tuple produced by ``self.serializer.deserialize_event``
    with the list of blobs appended as an extra element.

    Raises ProtocolError when the connection is closed or the header does
    not start with STX.
    """
    header = b''
    while len(header) < 8:
        chunk = self.event_sock.recv(8 - len(header))
        if not chunk:
            raise ProtocolError('read: event connection broken')
        header += chunk
    if header[0:1] != STX:
        raise ProtocolError('wrong event header')

    nblobs = ord(header[3:4])
    length, = LENGTH.unpack(header[4:])

    # fill a pre-allocated buffer in place so that large payloads are not
    # copied around several times
    payload = bytearray(length)
    view = memoryview(payload)
    filled = 0
    while filled < length:
        count = self.event_sock.recv_into(view[filled:], length - filled)
        if not count:
            raise ProtocolError('read: event connection broken')
        filled += count

    # XXX: error handling
    event = code2event[header[1:3]]
    data = self.serializer.deserialize_event(payload, event)

    blobs = []
    for _ in range(nblobs):
        blobs.append(self._recv_blob())
    return data + (blobs, )
def recv_reply(self):
    """Read one reply frame from the command socket and decode it.

    A reply is either a bare ACK, or NAK/STX followed by a 4-byte length
    (decoded with ``LENGTH``) and a payload of that many bytes.

    Returns ``(True, None)`` for a bare ACK, otherwise the result of
    ``self.serializer.deserialize_reply``.  If no serializer has been
    selected yet, every class in ``SERIALIZERS`` is tried on this first
    payload and the first one that decodes it without raising is kept.

    Raises ProtocolError when the peer closes the connection, the reply
    does not start with ACK, NAK or STX, or no serializer can decode the
    first payload.
    """
    # receive first byte + (possibly) length
    start = b''
    while len(start) < 5:
        data = self.sock.recv(5 - len(start))
        if not data:
            raise ProtocolError('connection broken')
        start += data
        # a bare ACK carries no length field, so stop as soon as it is seen
        if start == ACK:
            return True, None
        if start[0:1] not in (NAK, STX):
            raise ProtocolError('invalid response %r' % start)
    # it has a length...
    length, = LENGTH.unpack(start[1:])
    chunks = []
    remaining = length
    while remaining > 0:
        # never request more than the frame still owes us; reading a full
        # READ_BUFSIZE here could swallow the start of the next frame
        read = self.sock.recv(min(READ_BUFSIZE, remaining))
        if not read:
            raise ProtocolError('connection broken')
        chunks.append(read)
        remaining -= len(read)
    buf = b''.join(chunks)

    success = start[0:1] == STX
    if not self.serializer:
        # determine serializer class automatically
        for serializercls in SERIALIZERS.values():
            try:
                candidate = serializercls()
                candidate.deserialize_reply(buf, success)
            except Exception:
                # this serializer cannot decode the payload, try the next
                continue
            self.serializer = candidate
            break
        else:  # no serializer found
            raise ProtocolError('no serializer found for this connection')
    # XXX: handle errors
    return self.serializer.deserialize_reply(buf, success)
def recv_event(self):
    """Read one event frame from the event socket.

    The 7-byte header is STX (1 byte) + event code (2) + payload length
    (4, decoded with ``LENGTH``), followed by the payload.

    Returns the result of ``self.serializer.deserialize_event`` when
    ``DAEMON_EVENTS[event][0]`` is true; otherwise returns the tuple
    ``(event, memory_buffer(payload))`` with the undecoded payload.

    Raises ProtocolError when the connection is closed or the header does
    not start with STX.
    """
    # receive STX (1 byte) + eventcode (2) + length (4)
    start = b''
    while len(start) < 7:
        data = self.event_sock.recv(7 - len(start))
        if not data:
            raise ProtocolError('read: event connection broken')
        start += data
    if start[0:1] != STX:
        raise ProtocolError('wrong event header')
    length, = LENGTH.unpack(start[3:])
    got = 0
    # read into a pre-allocated buffer to avoid copying lots of data
    # around several times
    buf = np.zeros(length, 'c')  # Py3: replace with bytearray+memoryview
    while got < length:
        read = self.event_sock.recv_into(buf[got:], length - got)
        if not read:
            raise ProtocolError('read: event connection broken')
        got += read
    # XXX: error handling
    event = code2event[start[1:3]]
    # serialized or raw event data?
    if DAEMON_EVENTS[event][0]:
        # tobytes() returns the same bytes as the deprecated tostring()
        data = self.serializer.deserialize_event(buf.tobytes(), event)
    else:
        data = event, memory_buffer(buf)
    return data
def send_error_reply(self, reason):
    """Send an error reply frame carrying *reason* on the command socket.

    *reason* is serialized with ``self.serializer.serialize_error_reply``
    and sent as NAK + 4-byte length (packed with ``LENGTH``) + payload.

    Raises ProtocolError if serialization fails or a socket error occurs
    while sending.
    """
    try:
        payload = self.serializer.serialize_error_reply(reason)
    except Exception:
        raise ProtocolError('send_error_reply: could not serialize')

    frame = NAK + LENGTH.pack(len(payload)) + payload
    try:
        self.sock.sendall(frame)
    except socket.error as err:
        raise ProtocolError('send_error_reply: connection broken (%s)' % err)
def send_ok_reply(self, payload):
    """Send a success reply on the command socket.

    If *payload* is None a bare ACK is sent.  Otherwise *payload* is
    serialized with ``self.serializer.serialize_ok_reply`` and sent as
    STX + 4-byte length (packed with ``LENGTH``) + serialized data.

    Raises ProtocolError if serialization fails or a socket error occurs
    while sending.
    """
    try:
        if payload is None:
            message = ACK
        else:
            try:
                body = self.serializer.serialize_ok_reply(payload)
            except Exception:
                raise ProtocolError('send_ok_reply: could not serialize')
            message = STX + LENGTH.pack(len(body)) + body
        self.sock.sendall(message)
    except socket.error as err:
        raise ProtocolError('send_ok_reply: connection broken (%s)' % err)
def _recv_blob(self):
    """Read one length-prefixed blob from the event socket.

    The blob is preceded by a 4-byte length (decoded with ``LENGTH``).

    Returns a numpy character array of that length holding the blob bytes.

    Raises ProtocolError when the connection is closed before the blob
    has been received completely.
    """
    prefix = b''
    while len(prefix) < 4:
        chunk = self.event_sock.recv(4 - len(prefix))
        if not chunk:
            raise ProtocolError('read: event connection broken')
        prefix += chunk
    length, = LENGTH.unpack(prefix)

    # Py3: replace with bytearray+memoryview
    blob = np.zeros(length, 'c')
    received = 0
    while received < length:
        # receive straight into the not yet filled tail of the array
        count = self.event_sock.recv_into(blob[received:], length - received)
        if not count:
            raise ProtocolError('read: event connection broken')
        received += count
    return blob
def get_event(self):
    """Receive one multipart event message from the event socket.

    The message must consist of at least three frames; frame 1 holds the
    UTF-8 encoded event name and frame 2 the event data.

    Returns the result of ``self.serializer.deserialize_event`` when
    ``DAEMON_EVENTS[event][0]`` is true, otherwise the undecoded data
    frame.

    Raises ProtocolError if fewer than three frames were received.
    """
    item = self.event_sock.recv_multipart()
    if len(item) < 3:
        raise ProtocolError('invalid frames received')
    event = from_utf8(item[1])
    # serialized or raw event data?
    if DAEMON_EVENTS[event][0]:
        # hand over the decoded name (the same key used for the
        # DAEMON_EVENTS lookup), not the undecoded frame
        return self.serializer.deserialize_event(item[2], event)
    # NOTE(review): only the data frame is returned here, without the event
    # name -- confirm that callers expect this shape for raw events
    return item[2]
def recv_reply(self):
    """Receive one multipart reply from the command socket and decode it.

    The message must consist of at least three frames; frame 0 equals
    ``b'ok'`` for a successful reply and frame 2 holds the payload.

    Returns the result of ``self.serializer.deserialize_reply``.  If no
    serializer has been selected yet, it is determined from this payload
    via ``self.determine_serializer``.

    Raises ProtocolError if fewer than three frames were received.
    """
    frames = self.sock.recv_multipart()
    if len(frames) < 3:
        raise ProtocolError('invalid frames received')

    payload = frames[2]
    success = frames[0] == b'ok'
    if not self.serializer:
        self.serializer = self.determine_serializer(payload, success)
    return self.serializer.deserialize_reply(payload, success)
def connect(self, conndata, eventmask=None):
    """Connect to a NICOS daemon.

    *conndata* is a ConnectionData object.

    *eventmask* is a tuple of event names that should not be sent to this
    client.

    Connection and handshake failures are reported through the 'failed'
    signal and make the method return early.  After a successful login
    the event connection is opened, the event handler thread is started
    and the 'connected' signal is emitted.

    Raises RuntimeError if the client is already connected.
    """
    self.disconnecting = False
    if self.isconnected:
        raise RuntimeError('client already connected')

    # open the command connection
    try:
        self.transport.connect(conndata)
    except socket.error as err:
        # prefer the message part of the socket error if there is one
        if len(err.args) >= 2:
            reason = err.args[1]
        else:
            reason = str(err)
        self.signal('failed', 'Server connection failed: %s.' % reason, err)
        return
    except Exception as err:
        self.signal('failed', 'Server connection failed: %s.' % err, err)
        return

    # read banner
    try:
        success, banner = self.transport.recv_reply()
        if not success:
            raise ProtocolError('invalid response format')
        if 'daemon_version' not in banner:
            raise ProtocolError('daemon version missing from response')
        daemon_proto = banner.get('protocol_version', 0)
        if daemon_proto != PROTO_VERSION:
            if daemon_proto not in COMPATIBLE_PROTO_VERSIONS:
                raise ProtocolError('daemon uses protocol %d, but this '
                                    'client requires protocol %d, do you '
                                    'need to update NICOS?'
                                    % (daemon_proto, PROTO_VERSION))
            # older but still supported protocol
            self.compat_proto = daemon_proto
    except Exception as err:
        details = (conndata.host, conndata.port, err)
        self.signal('failed',
                    'Server (%s:%d) handshake failed: %s.' % details, err)
        return

    # log-in sequence
    self.isconnected = True

    password = conndata.password
    pw_hashing = banner.get('pw_hashing', 'sha1')

    if pw_hashing[0:4] == 'rsa,':
        if rsa is None:
            # no rsa module: fall back to the method named after 'rsa,'
            pw_hashing = pw_hashing[4:]
        else:
            encodedkey = banner.get('rsakey', None)
            if encodedkey is None:
                raise ProtocolError(
                    'rsa requested, but rsakey missing in banner')
            if not (PY2 or isinstance(encodedkey, bytes)):
                encodedkey = bytes(encodedkey, 'utf-8')
            pubkey = rsa.PublicKey.load_pkcs1(b64decode(encodedkey))
            encrypted = rsa.encrypt(to_utf8(password), pubkey)
            password = '******' + b64encode(encrypted).decode()

    if pw_hashing == 'sha1':
        password = hashlib.sha1(to_utf8(password)).hexdigest()
    elif pw_hashing == 'md5':
        password = hashlib.md5(to_utf8(password)).hexdigest()

    response = self.ask('authenticate', {
        'login': conndata.user,
        'passwd': password,
        'display': '',
    })
    if not response:
        self._close()
        return
    self.user_level = response['user_level']

    if eventmask:
        self.tell('eventmask', eventmask)

    self.transport.connect_events(conndata)

    # start event handler
    self.event_thread = createThread('event handler', self.event_handler)

    # remember what we are connected to
    self.host, self.port = conndata.host, conndata.port
    self.login = conndata.user
    self.viewonly = conndata.viewonly
    self.daemon_info = banner

    self.signal('connected')
def recv_event(self):
    """Receive one multipart event message from the event socket.

    The message must consist of at least three frames; frame 1 holds the
    encoded event name, frame 2 the serialized event data and all further
    frames are passed through as blobs.

    Returns the tuple produced by ``self.serializer.deserialize_event``
    with the list of blob frames appended as an extra element.

    Raises ProtocolError if fewer than three frames were received.
    """
    frames = self.event_sock.recv_multipart()
    if len(frames) < 3:
        raise ProtocolError('invalid frames received')

    event = frames[1].decode()
    payload = frames[2]
    blobs = frames[3:]

    decoded = self.serializer.deserialize_event(payload, event)
    return decoded + (blobs,)