def encryptMessage(msg, keys):
    """Wrap a protobuf message in an encrypted NSTPMessage envelope.

    ``msg`` is serialized and sealed with ``crypto_secretbox`` under
    ``keys[1]`` using a freshly generated 24-byte random nonce. The
    returned NSTPMessage carries the ciphertext and the nonce in its
    ``encrypted_message`` field.
    """
    fresh_nonce = nacl.utils.random(24)
    sealed = nacl.bindings.crypto_secretbox(
        msg.SerializeToString(), fresh_nonce, keys[1]
    )

    envelope = nstp_v3_pb2.NSTPMessage()
    encrypted = envelope.encrypted_message
    encrypted.ciphertext = sealed
    encrypted.nonce = fresh_nonce
    return envelope
def EncryptAndSend(s, res):
    """Encrypt ``res`` and queue the framed envelope for socket ``s``.

    The serialized message is sealed with ``crypto_secretbox`` under
    ``dict_session_keys[s][1]`` and a random nonce, wrapped in an
    NSTPMessage, printed, and stored (length-prefixed via ``append_len``)
    in ``response_message[s]``.
    """
    nonce = utils.random(secret.SecretBox.NONCE_SIZE)
    session_key = dict_session_keys[s][1]

    wrapper = nstp_v3_pb2.NSTPMessage()
    wrapper.encrypted_message.ciphertext = crypto_secretbox(
        res.SerializeToString(), nonce, session_key
    )
    wrapper.encrypted_message.nonce = nonce

    print(wrapper)
    response_message[s] = append_len(wrapper.SerializeToString())
def sendServerHello(msg):
    """Build the reply to a ClientHello.

    Returns ``error_message("Wrong version")`` when the client's major
    version is not 3; otherwise returns an NSTPMessage whose
    ``server_hello`` advertises version 3.1, the user agent
    "hello client", and the bytes of the module-level ``serverPublicKey``.
    """
    if msg.client_hello.major_version != 3:
        return error_message("Wrong version")

    reply = nstp_v3_pb2.NSTPMessage()
    hello = reply.server_hello
    hello.major_version = 3
    hello.minor_version = 1
    hello.user_agent = "hello client"
    # Reading a module global needs no ``global`` declaration.
    hello.public_key = bytes(serverPublicKey)
    return reply
def serverHello(s, user_agent):
    """Queue a ServerHello (version 3.2) for socket ``s``.

    The hello carries ``user_agent`` and the module-level ``pk_server``
    key; the serialized, length-prefixed message is stored in
    ``response_message[s]``.
    """
    print("inside server hello")

    reply = nstp_v3_pb2.NSTPMessage()
    hello = reply.server_hello
    hello.major_version = 3
    hello.minor_version = 2
    hello.user_agent = user_agent
    hello.public_key = pk_server

    response_message[s] = append_len(reply.SerializeToString())
def recv_input(c):
    """Read one length-prefixed NSTPMessage from ``c`` and dispatch it.

    A single ``recv`` of up to 4096 bytes is expected to contain exactly
    one frame: a 2-byte big-endian length followed by that many bytes of
    serialized NSTPMessage.

    Returns:
        The result of ``decider(nstp_msg)`` for a well-formed frame, or
        ``-1`` when nothing was received, the buffer is too short to hold
        the length prefix, or the declared length does not match the
        number of bytes received.
    """
    data = c.recv(4096)

    # The prefix itself needs two bytes; a shorter read cannot be a frame.
    if len(data) < 2:
        return -1

    # Decode the prefix numerically. Concatenating hex(data[0]) with an
    # unpadded format(data[1], 'x') drops the leading zero of the low
    # byte, so e.g. 0x0105 would be read as 0x15.
    declared_length = int.from_bytes(data[:2], "big")
    if len(data) != declared_length + 2:
        return -1

    nstp_msg = nstp_v3_pb2.NSTPMessage()
    nstp_msg.ParseFromString(data[2:])
    print(nstp_msg)
    return decider(nstp_msg)
def send(sock, obj, encrypt=True):
    """Serialize ``obj`` and send it over ``sock`` as a length-prefixed frame.

    When ``encrypt`` is true, the serialized bytes are first sealed with
    ``crypto_secretbox`` under the module-level ``client_tx`` key and a
    random nonce, and wrapped in an NSTPMessage ``encrypted_message``.
    The frame is a 2-byte big-endian length followed by the payload.
    """
    print("Sent: {0}".format(obj))
    payload = obj.SerializeToString()

    if encrypt:
        nonce = randombytes(crypto_secretbox_NONCEBYTES)
        envelope = nstp_v3_pb2.NSTPMessage()
        envelope.encrypted_message.ciphertext = crypto_secretbox(
            payload, nonce, client_tx
        )
        envelope.encrypted_message.nonce = nonce
        payload = envelope.SerializeToString()

    length_prefix = len(payload).to_bytes(2, byteorder="big")
    sock.sendall(length_prefix + payload)
def server_hello(client_hello):
    """Answer a ClientHello with a framed ServerHello.

    ``key_pair_generator`` is always invoked with the client's public key
    before the version is checked. For major version 3 the function
    returns the length-prefixed, serialized ServerHello (version 3.2,
    user agent "Client_Authentication", public key taken from
    ``SessionKeys.get("server_pk")``); for any other version it
    returns ``-1``.
    """
    print("Client_hello aagaya")
    key_pair_generator(client_hello.public_key)

    if client_hello.major_version != 3:
        return -1

    reply = nstp_v3_pb2.NSTPMessage()
    hello = reply.server_hello
    hello.major_version = 3
    hello.minor_version = 2
    hello.user_agent = "Client_Authentication"
    hello.public_key = SessionKeys.get("server_pk")
    print(SessionKeys.get("server_pk"))

    # Length prefix: the byte size rendered as at least four hex digits.
    prefix = bytes.fromhex(format(reply.ByteSize(), "04x"))
    return prefix + reply.SerializeToString()
def auth_request_handler(msg):
    """Handle a decrypted AuthRequest and return the framed encrypted reply.

    The username and password are passed to ``authenticator``. If it
    returns ``-1`` this function returns ``-1`` as well; otherwise its
    result is placed in ``auth_response.authenticated``, sealed with
    ``crypto_secretbox`` under ``SessionKeys.get('server_tx')``, and
    returned as a length-prefixed serialized NSTPMessage.
    """
    request = msg.auth_request
    authenticated = authenticator(request.username, request.password)
    if authenticated == -1:
        return -1

    plaintext = nstp_v3_pb2.DecryptedMessage()
    plaintext.auth_response.authenticated = authenticated

    nonce = nacl.bindings.randombytes(
        nacl.bindings.crypto_secretbox_NONCEBYTES)
    envelope = nstp_v3_pb2.NSTPMessage()
    envelope.encrypted_message.ciphertext = nacl.bindings.crypto_secretbox(
        plaintext.SerializeToString(), nonce, SessionKeys.get('server_tx'))
    envelope.encrypted_message.nonce = nonce

    prefix = bytes.fromhex(format(envelope.ByteSize(), "04x"))
    return prefix + envelope.SerializeToString()
def load_request_handler(msg):
    """Handle a decrypted LoadRequest and return the framed encrypted reply.

    The key is looked up in ``public_store`` when the request's ``public``
    flag is set, otherwise in ``SessionKeys.user_store``. A missing key
    yields an empty value. The LoadResponse is sealed with
    ``crypto_secretbox`` under ``SessionKeys.get('server_tx')`` and
    returned as a length-prefixed serialized NSTPMessage.
    """
    request = msg.load_request
    store = public_store if request.public else SessionKeys.user_store
    value = store.get(request.key)
    if value is None:
        value = b''

    plaintext = nstp_v3_pb2.DecryptedMessage()
    plaintext.load_response.value = value

    nonce = nacl.bindings.randombytes(
        nacl.bindings.crypto_secretbox_NONCEBYTES)
    envelope = nstp_v3_pb2.NSTPMessage()
    envelope.encrypted_message.ciphertext = nacl.bindings.crypto_secretbox(
        plaintext.SerializeToString(), nonce, SessionKeys.get('server_tx'))
    envelope.encrypted_message.nonce = nonce

    prefix = bytes.fromhex(format(envelope.ByteSize(), "04x"))
    return prefix + envelope.SerializeToString()
def process_response(sock):
    """Receive one framed NSTPMessage from ``sock`` and dispatch on its type.

    The frame is a 2-byte big-endian length followed by that many bytes of
    serialized NSTPMessage. ``server_hello`` messages go to
    ``process_server_hello`` and are printed, ``encrypted_message``
    messages go to ``process_encrypted_message``, ``error_message``
    messages are printed, and anything else is reported as unknown.

    Raises:
        ConnectionError: if the peer closes the connection before a
            complete header or body has been received.
    """

    def recv_exact(count):
        # recv() may legally return fewer bytes than requested, so keep
        # reading until the whole span has arrived.
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                raise ConnectionError(
                    "connection closed with {0} byte(s) still expected".format(
                        remaining))
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    header = recv_exact(2)
    message_length, = struct.unpack('>H', header)
    message = recv_exact(message_length)

    nstp_message = nstp_v3_pb2.NSTPMessage()
    nstp_message.ParseFromString(message)

    message_type = nstp_message.WhichOneof('message_')
    if message_type == 'server_hello':
        process_server_hello(nstp_message)
        print(nstp_message)
    elif message_type == 'encrypted_message':
        process_encrypted_message(nstp_message)
    elif message_type == 'error_message':
        print(nstp_message)
    else:
        print("Got unkwown NSTP message")
def store_request_handler(msg):
    """Handle a decrypted StoreRequest and return the framed encrypted reply.

    The value is written to ``public_store`` when the request's ``public``
    flag is set, otherwise to ``SessionKeys.user_store``. The
    StoreResponse carries the SHA-256 digest of the value with
    ``hash_algorithm`` set to 1, is sealed with ``crypto_secretbox`` under
    ``SessionKeys.get('server_tx')``, and is returned as a
    length-prefixed serialized NSTPMessage.
    """
    request = msg.store_request
    key, value = request.key, request.value

    target_store = public_store if request.public else SessionKeys.user_store
    target_store[key] = value

    plaintext = nstp_v3_pb2.DecryptedMessage()
    plaintext.store_response.hash = hashlib.sha256(value).digest()
    plaintext.store_response.hash_algorithm = 1

    nonce = nacl.bindings.randombytes(
        nacl.bindings.crypto_secretbox_NONCEBYTES)
    envelope = nstp_v3_pb2.NSTPMessage()
    envelope.encrypted_message.ciphertext = nacl.bindings.crypto_secretbox(
        plaintext.SerializeToString(), nonce, SessionKeys.get('server_tx'))
    envelope.encrypted_message.nonce = nonce

    prefix = bytes.fromhex(format(envelope.ByteSize(), "04x"))
    return prefix + envelope.SerializeToString()
def ping_request_handler(msg):
    """Handle a decrypted PingRequest and return the framed encrypted reply.

    ``hash_algorithm`` selects what goes into ``ping_response.hash``:
    0 echoes the data unchanged, 1 stores its SHA-256 digest, 2 stores its
    SHA-512 digest. Any other value leaves the response unset. The result
    is sealed with ``crypto_secretbox`` under
    ``SessionKeys.get('server_tx')`` and returned as a length-prefixed
    serialized NSTPMessage.
    """
    request = msg.ping_request
    algorithm = request.hash_algorithm

    plaintext = nstp_v3_pb2.DecryptedMessage()
    if algorithm == 0:
        plaintext.ping_response.hash = request.data
    elif algorithm == 1:
        plaintext.ping_response.hash = hashlib.sha256(request.data).digest()
    elif algorithm == 2:
        plaintext.ping_response.hash = hashlib.sha512(request.data).digest()

    nonce = nacl.bindings.randombytes(
        nacl.bindings.crypto_secretbox_NONCEBYTES)
    envelope = nstp_v3_pb2.NSTPMessage()
    envelope.encrypted_message.ciphertext = nacl.bindings.crypto_secretbox(
        plaintext.SerializeToString(), nonce, SessionKeys.get('server_tx'))
    envelope.encrypted_message.nonce = nonce

    prefix = bytes.fromhex(format(envelope.ByteSize(), "04x"))
    return prefix + envelope.SerializeToString()
def error_message(reason):
    """Return an NSTPMessage whose ``error_message`` text is ``reason``.

    The message is returned unserialized and unencrypted; framing and
    encryption are left to the caller.
    """
    failure = nstp_v3_pb2.NSTPMessage()
    failure.error_message.error_message = reason
    return failure
def _send_framed(c, message):
    """Serialize ``message`` and send it with a 2-byte network-order length."""
    payload = message.SerializeToString()
    c.sendall(struct.pack("!H", len(payload)) + payload)


def _release_preauth_slot(remote):
    """Decrement ``IPtoPreauth[remote]`` while holding ``lock``."""
    with lock:
        IPtoPreauth[remote] -= 1


def connection_thread(c, addr):
    """Serve one client connection: handshake, then the encrypted loop.

    Handshake: the first frame must be a ``client_hello`` with a non-empty
    public key and an accepted version. Session keys are derived with
    ``crypto_kx_server_session_keys``. Any handshake failure sends a
    plaintext error, decrements the per-IP counter, and closes the socket.

    Request loop: each frame must be an ``encrypted_message``. Auth
    requests are counted per connection: after 5 attempts the handler
    sleeps for ``abs(IPtoPreauth[remote])`` seconds before processing,
    and after 40 attempts the request is refused. Non-auth requests are
    only processed once ``authenticated`` is true. The loop ends when the
    peer disconnects or when a response carries an ``error_message``.

    Args:
        c: connected client socket.
        addr: peer address; ``addr[0]`` is used as the key into
            ``IPtoPreauth``.

    Returns:
        Always ``0``.
    """
    remote = addr[0]
    print("REMOTE: ", remote)

    # ---- Handshake -------------------------------------------------------
    lengthInBytes = recv_all(c, 2)
    # A header shorter than two bytes (including an empty read) means the
    # peer went away; unpacking it would raise struct.error.
    if len(lengthInBytes) < 2:
        c.close()
        _release_preauth_slot(remote)
        return 0
    length = struct.unpack("!H", lengthInBytes)[0]
    msg = recv_all(c, length)
    read = nstp_v3_pb2.NSTPMessage()
    read.ParseFromString(msg)
    print(read)

    attempts = 0
    authenticated = False
    user = ""
    keys = None
    end = False

    if read.HasField("client_hello"):
        clientPublicKey = read.client_hello.public_key
        if clientPublicKey == b'':
            response = error_message("Must include a public_key")
            end = True
        else:
            response = sendServerHello(read)
            if response.HasField("error_message"):
                # sendServerHello rejected the hello (wrong version). End
                # the session, as is done for every other error response.
                end = True
            else:
                try:
                    keys = nacl.bindings.crypto_kx_server_session_keys(
                        serverPublicKey.encode(),
                        serverSecretKey.encode(),
                        clientPublicKey,
                    )
                except nacl.exceptions.CryptoError:
                    response = error_message("Session Key failure")
                    end = True
    else:
        response = error_message("Must send a client hello first")
        end = True

    _send_framed(c, response)
    if end:
        _release_preauth_slot(remote)
        c.close()
        return 0

    # ---- Encrypted request loop ------------------------------------------
    while True:
        lengthInBytes = recv_all(c, 2)
        if len(lengthInBytes) < 2:
            break
        print(lengthInBytes)
        length = struct.unpack("!H", lengthInBytes)[0]
        msg = recv_all(c, length)
        read = nstp_v3_pb2.NSTPMessage()
        read.ParseFromString(msg)
        print("READ", read)

        if read.HasField("encrypted_message"):
            decryptedMsg = decryptMessage(read, keys)
            if decryptedMsg.HasField("error_message"):
                # decryptMessage reported a failure; relay it to the client.
                plaintextResponse = decryptedMsg
            elif decryptedMsg.HasField("auth_request"):
                with lock:
                    openConnections = IPtoPreauth[remote]
                attempts += 1
                if attempts > 40:
                    plaintextResponse = error_message(
                        "Too many attempts on this connection")
                    # This decrement ran without the lock before; every
                    # other mutation of IPtoPreauth is lock-protected.
                    _release_preauth_slot(remote)
                elif attempts > 5:
                    # Throttle repeated attempts in proportion to the
                    # per-IP counter value read above.
                    sleepTime = abs(openConnections)
                    time.sleep(sleepTime)
                    print("ERROR - too many attempts. Sleeping for: ",
                          sleepTime)
                    plaintextResponse, user, authenticated = messageType(
                        decryptedMsg, authenticated, user, remote)
                else:
                    plaintextResponse, user, authenticated = messageType(
                        decryptedMsg, authenticated, user, remote)
            elif authenticated:
                plaintextResponse, user, authenticated = messageType(
                    decryptedMsg, authenticated, user, remote)
            else:
                plaintextResponse = error_message("Must be authenticated first")
            print("PLAINTEXT RESPONSE\n", plaintextResponse)
            print("AUTHENTICATED\n", authenticated, " ", user)
        else:
            print("wrong message type set")
            plaintextResponse = error_message("Wrong message type sent")

        response = encryptMessage(plaintextResponse, keys)
        _send_framed(c, response)

        if plaintextResponse.HasField("error_message"):
            print("Connection with client has been closed")
            break

    # NOTE(review): apart from the >40-attempts branch, leaving the loop
    # does not decrement IPtoPreauth here; presumably messageType adjusts
    # the counter on successful authentication. Confirm.
    c.close()
    print("total connections: ", IPtoPreauth)
    print("returning out of thread ", addr[0])
    return 0
decrypted_message_type = decrypted_message.WhichOneof('message_') print(decrypted_message) if __name__ == '__main__': server_address = ('localhost', 22300) global client_public, client_private client_public, client_private = crypto_kx.crypto_kx_keypair() messages = list() cases = list() ############################################################ # Test case0: check out-of-spec protocol m0 = nstp_v3_pb2.NSTPMessage() m0.client_hello.major_version = 1000 m0.client_hello.minor_version = 1 m0.client_hello.user_agent = 'The user' m0.client_hello.public_key = client_public m0 = (m0, False) m1 = nstp_v3_pb2.NSTPMessage() m1.client_hello.major_version = 3 m1.client_hello.minor_version = 1 m1.client_hello.user_agent = 'The user' m1.client_hello.public_key = client_public m1 = (m1, False) # Uncomment this line to use this message # messages.append(m0)
) clientsocket.close() else: list_ip_failed_logins[clientsocket.getpeername()[0]] = 0 list_ip_authenticated[clientsocket] = 0 clientsocket.setblocking(False) incoming.append(clientsocket) else: print("data read from socket", s) data = s.recv(2) if data: outgoing.append(s) len_msg = struct.unpack('!H', data[:2]) print("length of message to be received {}".format(len_msg[0])) full_msg = recv_full_msg(len_msg[0], s) nstp_msg = nstp_v3_pb2.NSTPMessage() nstp_msg.ParseFromString(full_msg) print(nstp_msg) switcher = { 'client_hello': handleClientHello, 'encrypted_message': handleEncryptedMessage } func = switcher.get(nstp_msg.WhichOneof("message_")) func(s, nstp_msg) else: incoming.remove(s) for s in writes: print("inside write") try: print("writing data to socket---->", s) if s in response_message:
def error_message():
    """Return a framed, serialized NSTP error message.

    The error text is fixed ("I am terminating you"). The result is the
    message's byte size, rendered as at least four hex digits and
    converted to bytes, followed by the serialized NSTPMessage.
    """
    termination = nstp_v3_pb2.NSTPMessage()
    termination.error_message.error_message = "I am terminating you"

    prefix = bytes.fromhex(format(termination.ByteSize(), "04x"))
    return prefix + termination.SerializeToString()
msg = b'' while n > 0: chunk = s.recv(n) n = n - len(chunk) msg = msg + chunk return msg HOST = 'localhost' PORT = 22300 s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.connect((HOST, PORT)) pk_client, sk_client = crypto_kx.crypto_kx_keypair() #Clienthello nstp_msg = nstp_v3_pb2.NSTPMessage() ch = nstp_v3_pb2.ClientHello() ch.major_version = 1 ch.minor_version = 2 ch.user_agent = "hi server" ch.public_key = pk_client nstp_msg.client_hello.CopyFrom(ch) print("sending data") s.send(append_len(nstp_msg.SerializeToString())) data = s.recv(2) print("receiving data") len_msg = struct.unpack('!H', data[:2]) print("length of message to be received {}".format(len_msg[0])) full_msg = recv_full_msg(len_msg[0], s) res_msg = nstp_v3_pb2.NSTPMessage()
def plaintextErrorRes(s, msg):
    """Queue an unencrypted error message with text ``msg`` for socket ``s``.

    The ErrorMessage is copied into an NSTPMessage, serialized,
    length-prefixed via ``append_len``, and stored in
    ``response_message[s]``.
    """
    failure = nstp_v3_pb2.ErrorMessage()
    failure.error_message = msg

    wrapper = nstp_v3_pb2.NSTPMessage()
    wrapper.error_message.CopyFrom(failure)

    framed = append_len(wrapper.SerializeToString())
    response_message[s] = framed