def assemble_message_loop(self):
    """Assemble the payload fragments into a message, and send it to upper layer.

    Loops until ``self.terminate_socket_event`` is set. Every item taken from
    ``self.payload_queue`` is interpreted as follows: ``0x4`` finishes the
    current message, which is then put in ``self.finished_message_queue``;
    ``0x5`` is skipped (presumably a wake-up marker - confirm with the
    producer); any other item is appended to the message being assembled.
    """
    message = bytearray()
    # Read until end of transmission
    while True:
        current_payload = None
        try:
            current_payload = self.payload_queue.get(block=True)
        except queue.Empty:
            # Whether queue was empty or not, a check to self.terminate_socket_event must be performed either way
            pass
        else:
            # Acknowledge every retrieved item, not only the end-of-message marker, so the queue's
            # unfinished-task counter always returns to zero and payload_queue.join() cannot hang
            self.payload_queue.task_done()
        if self.terminate_socket_event.is_set():
            break
        if current_payload is None:
            # Nothing was retrieved, there is nothing to assemble
            continue
        if current_payload == 0x4:
            # Message ended, put it in the finished_message_queue and reset variables
            utility.log_message(f"Finished reading the message {message}!", self.log_filename, self.log_file_lock)
            # NOTE(review): Queue.put receives self.partner as its `block` argument here, so only the
            # message itself is enqueued. Confirm whether the consumer expects a (message, partner) pair.
            self.finished_message_queue.put(message, self.partner)
            message = bytearray()
        elif current_payload != 0x5:
            # Add the next payload to the end of the message
            message.extend(current_payload)
    utility.log_message("Assemble message loop finished!", self.log_filename, self.log_file_lock)
def send_reachability_table(self, ip, port):
    """Encode the reachability table as an update packet and send it to ip:port.

    The entry that refers to the receiver itself is left out. Nothing is sent
    when no entries remain after that exclusion.

    :param ip: ip address of the receiving node.
    :param port: port of the receiving node.
    """
    # The context manager guarantees the lock is released even if encoding an entry raises
    with self.reachability_table_lock:
        table_size = len(self.reachability_table)
        # Should not send an entry with the receiver's own address
        if (ip, port) in self.reachability_table:
            table_size -= 1
        if table_size <= 0:
            return
        encoded_message = bytearray(PKT_TYPE_SIZE + TUPLE_COUNT_SIZE + TUPLE_SIZE * table_size)
        # Message type
        struct.pack_into("!B", encoded_message, 0, PKT_TYPE_UPDATE)
        # 2 bytes for the amount of tuples
        struct.pack_into("!H", encoded_message, PKT_TYPE_SIZE, table_size)
        # Iterate the reachability table, writing each tuple to the encoded_message buffer
        offset = PKT_TYPE_SIZE + TUPLE_COUNT_SIZE  # Points to the next empty space in the buffer
        for (r_ip, r_port), (r_mask, _, r_cost) in self.reachability_table.items():
            # Add entry to message only if it does not refer to the receiving node
            if r_ip == ip and r_port == port:
                continue
            ip_tuple = tuple(int(tok) for tok in r_ip.split('.'))
            encoded_message[offset:offset + TUPLE_SIZE] = utility.encode_tuple(ip_tuple, r_port, r_mask, r_cost)
            offset += TUPLE_SIZE
    utility.log_message(f"Sending reachability table of {len(encoded_message)} bytes to {ip}:{port}", self)
    self.send_message(ip, port, encoded_message)
def read_messages_loop(self):
    """Read datagrams from the socket into the message queue until the stopper event is set.

    Datagrams that arrive while ``self.ignore_updates`` is set are dropped.
    A flood or dead-node packet replaces the whole queue before being enqueued.
    """
    while not self.stopper.is_set():
        try:
            message, address = self.sock.recvfrom(BUFFER_SIZE)
        except (socket.timeout, ConnectionResetError):
            # Nothing usable was read, check the stopper again
            continue
        if self.ignore_updates.is_set():
            # Continue without putting the message in the queue if a flood occurred recently
            continue
        message_type = int.from_bytes(message[:PKT_TYPE_SIZE], byteorder='big', signed=False)
        with self.message_queue_lock:
            if message_type in (PKT_TYPE_FLOOD, PKT_TYPE_DEAD):
                # Flood messages have more priority and the queue will need to be flushed no matter what,
                # so start from a fresh queue and put the flood message first
                self.message_queue = queue.Queue()
            # All other messages have the same priority and are simply appended
            self.message_queue.put((message, address))
    utility.log_message("Finished the read messages loop!", self)
def reset_ignore_updates(self):
    """Stop ignoring incoming messages and probe the neighbors again."""
    utility.log_message("Resuming message listening...", self)
    # Clearing the event lets read_messages_loop enqueue messages again
    self.ignore_updates.clear()
    # Awaken neighbors again
    self.find_awake_neighbors()
def send_cost_change(self, ip, port, new_cost):
    """Send a cost change packet to ip:port.

    The packet is one byte of packet type followed by the three low-order
    bytes of ``new_cost`` in big-endian order.

    :param ip: ip address of the receiving node.
    :param port: port of the receiving node.
    :param new_cost: the new cost, packed as an unsigned 32 bit integer before truncation.
    """
    message = bytearray(1)
    struct.pack_into("!B", message, 0, PKT_TYPE_COST_CHANGE)
    new_cost_bytes = bytearray(4)
    struct.pack_into("!I", new_cost_bytes, 0, new_cost)
    # NOTE(review): only 3 of the 4 packed bytes are sent, so a cost above 0xFFFFFF silently loses its
    # most significant byte. Confirm whether such costs should be rejected instead.
    packet = message + new_cost_bytes[1:]
    # Log the size of what is actually sent, not only the size of the type byte
    utility.log_message(f"Sending cost change message of {len(packet)} bytes to {ip}:{port}", self)
    self.send_message(ip, port, packet)
def update_reachability_table(self, ip, port, mask, cost, through_node):
    """Record the route to ip:port through a neighbor when it is new or cheaper.

    :param ip: ip address of the destination.
    :param port: port of the destination.
    :param mask: mask stored with the destination entry.
    :param cost: cost advertised by ``through_node`` to reach the destination.
    :param through_node: key in ``self.neighbors`` of the advertising neighbor.
    """
    destination = (ip, port)
    with self.neighbors_lock:
        # The advertised cost plus the cost of the link to the advertising neighbor
        total_cost = cost + self.neighbors[through_node][1]
    # Write to the reachability table,
    # as many threads may perform read/write we need to lock it
    with self.reachability_table_lock:
        current_entry = self.reachability_table.get(destination)
        is_improvement = current_entry is None or current_entry[2] > total_cost
        if is_improvement:
            utility.log_message(f"Changing cost of {ip}:{port} passing through {through_node}.", self)
            self.reachability_table[destination] = (mask, through_node, total_cost)
def receive_packet(self):
    """Read one packet from the socket and log its header fields.

    :return: the ``(packet, address)`` pair returned by ``self.sock.recvfrom``.
    """
    # The context manager releases sock_read_lock even when recvfrom raises
    with self.sock_read_lock:
        packet, address = self.sock.recvfrom(utility.PACKET_SIZE)
    utility.log_message(
        f"Received packet {utility.packet_to_string(packet)} with SN={utility.get_sn(packet)} and "
        f"RN={utility.get_rn(packet)} and ACK={utility.are_flags_set(packet, utility.HEADER_ACK)} "
        f"from {address}", self.log_filename, self.log_file_lock)
    return packet, address
def get_current_status(self):
    """Return the current status object, read while holding ``current_status_lock``.

    The previous implementation compared the status name with itself inside the
    same critical section, so its "changed status" log could never trigger and
    has been removed.

    :return: the object stored in ``self.current_status``.
    """
    with self.current_status_lock:
        return self.current_status
def __init__(self, address, sock, sock_lock, finished_message_queue, closed_connections_queue, log_filename,
             log_file_lock):
    """Set up the state of one connection and start its two worker threads.

    :param address: ip-port pair of the partner; 'localhost' is replaced by '127.0.0.1'.
    :param sock: socket used to send packets to the partner.
    :param sock_lock: lock stored as ``sock_send_lock``, shared among all connections.
    :param finished_message_queue: queue shared among all connections for assembled messages.
    :param closed_connections_queue: queue shared among all connections.
    :param log_filename: passed to ``utility.log_message``.
    :param log_file_lock: passed to ``utility.log_message``, shared among all connections.
    """
    # Change localhost to 127.0.0.1 from now so the address can be written as the current partner
    if address[0] == 'localhost':
        address = ('127.0.0.1', address[1])

    # TODO remove
    self.times_notified = 0
    self.times_unblocked = 0

    # Logging
    self.log_filename = log_filename
    self.log_file_lock = log_file_lock

    # Socket and the lock protecting writes to it (shared among all connections)
    self.sock = sock
    self.sock_send_lock = sock_lock

    # State variables
    self.partner = address
    self.current_status = states.ClosedStatus()
    self.current_sn = 0
    self.current_rn = 0

    # Queues used only in this socket
    self.send_queue = queue.Queue()
    self.receive_queue = queue.Queue()
    self.payload_queue = queue.Queue()

    # Queues shared among all connection
    self.finished_message_queue = finished_message_queue
    self.closed_connections_queue = closed_connections_queue

    # Locks used only in this socket
    self.sock_read_lock = threading.Lock()
    self.send_queue_lock = threading.Lock()
    self.current_status_lock = threading.Lock()
    self.current_sn_lock = threading.Lock()
    self.current_rn_lock = threading.Lock()

    # Events
    self.terminate_socket_event = threading.Event()

    # Threads, started only after the state above is in place
    self.handler_thread = threading.Thread(target=self.main_loop)
    self.handler_thread.start()
    self.message_assembly_thread = threading.Thread(target=self.assemble_message_loop)
    self.message_assembly_thread.start()

    utility.log_message(f"Socket for {address} has been started!", self.log_filename, self.log_file_lock)
def close(self):
    """Wait for the pending queue tasks, then send a FIN packet to the partner."""
    # Block until all tasks in the queues are done
    for pending_queue in (self.send_queue, self.receive_queue):
        pending_queue.join()

    utility.log_message(f"Close {self.partner}, sending FIN...", self.log_filename, self.log_file_lock)
    fin_packet = utility.create_packet(fin=True)
    self.increase_current_sn()
    self.send_packet(fin_packet)
    # The packet is also placed in the send queue after being sent
    self.send_queue.put(fin_packet)
    self.set_current_status(states.FinSentStatus)
def send_packet(self, packet):
    """Stamp the packet with the current RN and SN, log it and send it to the partner.

    :param packet: mutable packet buffer; positions 1 and 2 are overwritten with RN and SN.
    """
    # Write RN and SN
    packet[1] = self.get_current_rn()
    packet[2] = self.get_current_sn()
    utility.log_message(
        f"Sending packet {utility.packet_to_string(packet)} with SN={utility.get_sn(packet)} and "
        f"RN={utility.get_rn(packet)} and ACK={utility.are_flags_set(packet, utility.HEADER_ACK)} to "
        f"{self.get_partner()}", self.log_filename, self.log_file_lock)
    # The context manager releases sock_send_lock even when sendto raises
    with self.sock_send_lock:
        self.sock.sendto(packet, self.get_partner())
def read_loop(self):
    """Puts packets in the receive queue.

    Runs until ``self.terminate_node_event`` is set. A packet from an address
    with an entry in ``self.connections`` is routed to that connection. A SYN
    packet from an unknown address creates a new connection when the node is
    accepting connections. Any other packet is ignored.
    """
    while not self.terminate_node_event.is_set():
        # FIXME: receive_packets recvfrom call blocks
        packet, address = self.receive_packet()

        # Randomly drop some packets to test the Stop-And-Wait algorithm
        # TODO uncomment this
        # if random.randint(1, 10) == 1:
        #     utility.log_message("Oops! Dropped a packet...", self.log_filename, self.log_file_lock)

        # Context managers release the locks even if handling the packet raises.
        # Acquisition order: connections_lock first, accepting_connections_lock second.
        with self.connections_lock:
            # If a connection is established with this address, send the packet to that connection
            if address in self.connections:
                utility.log_message(f"Routing packet to {address}", self.log_filename, self.log_file_lock)
                self.connections[address].receive_queue.put(packet)
            # If new connections are being accepted and the incoming message indicates it wants to initiate a
            # connection, allocate resources for the connection
            elif utility.are_flags_set(packet, utility.HEADER_SYN):
                with self.accepting_connections_lock:
                    if self.accepting_connections:
                        utility.log_message(
                            f"Creating a new socket to handle incoming connection to {address}",
                            self.log_filename, self.log_file_lock)
                        new_connection = PseudoTCPSocket(
                            address, self.sock, self.sock_send_lock, self.finished_messages_queue,
                            self.closed_connections_queue, self.log_filename, self.log_file_lock)
                        new_connection.set_current_status(states.AcceptStatus)
                        new_connection.deliver_packet(packet)
                        self.connections[address] = new_connection
                    else:
                        utility.log_message(
                            f"{address} wanted to initiate a connection, but the node is not accepting "
                            f"connections. Ignoring...", self.log_filename, self.log_file_lock)
            else:
                utility.log_message(
                    f"The message from {address} didn't have an open connection and didn't contain a "
                    f"SYN, ignoring...", self.log_filename, self.log_file_lock)
    utility.log_message("Read loop finished!", self.log_filename, self.log_file_lock)
def initiate_connection(self):
    """Send a SYN packet with a random sequence number to the partner."""
    utility.log_message(f"Trying to connect to {self.partner}...", self.log_filename, self.log_file_lock)

    # Build the SYN message, choosing a random value for sn
    initial_sn = random.randint(0, 255)
    self.set_current_sn(initial_sn)
    syn_message = utility.create_packet(syn=True, sn=self.get_current_sn())

    # Send SYN
    utility.log_message(f"Sending SYN...", self.log_filename, self.log_file_lock)
    self.send_packet(syn_message)
    self.set_current_status(states.SynSentStatus())

    # Keep the SYN in the send_queue, as it might need to be resent if the packet is lost
    self.send_queue.put(syn_message)
def send(self, message, address):
    """Send a message through the connection registered for an address.

    Nothing is sent, and only a log entry is written, when no connection exists
    for the address.

    :param message: passed unchanged to the connection's ``send``.
    :param address: the ip-port pair of the connection, resolved with ``utility.resolve_localhost``.
    """
    address = utility.resolve_localhost(address)
    utility.log_message(f"Sending a message to {address}...", self.log_filename, self.log_file_lock)
    # The context manager releases connections_lock on every exit path; the previous early return
    # left the lock held forever when the connection did not exist
    with self.connections_lock:
        # Check if the connection exists
        if address not in self.connections:
            utility.log_message(
                f"Tried sending a message to {address}, but that connection didn't exist!",
                self.log_filename, self.log_file_lock)
            return
        # Send the message
        self.connections[address].send(message)
def close_all_connections(self):
    """
    Closes all current connections. The node will not be closed.

    While the connections are being closed ``self.accepting_connections`` is set
    to False; its previous value is restored afterwards, even if closing a
    connection raises.
    """
    utility.log_message(
        f"Closing all connections, will wait until all operations are complete...",
        self.log_filename, self.log_file_lock)
    # read_loop acquires connections_lock before accepting_connections_lock; taking them in the same
    # order here avoids a lock-order inversion between the two threads
    with self.connections_lock:
        with self.accepting_connections_lock:
            old_value = self.accepting_connections
            self.accepting_connections = False
            try:
                for connection in self.connections.values():
                    connection.close()
            finally:
                # Restore the flag no matter how the loop ended
                self.accepting_connections = old_value
def connect(self, address):
    """Create a connection to an address and start its handshake.

    Only a log entry is written when a connection for that address already exists.

    :param address: the ip-port pair to connect to, resolved with ``utility.resolve_localhost``.
    """
    address = utility.resolve_localhost(address)
    # The context manager releases connections_lock even if creating the socket or sending the SYN raises
    with self.connections_lock:
        if address not in self.connections:
            # Allocate resources for this new connection
            new_connection = PseudoTCPSocket(address, self.sock, self.sock_send_lock,
                                             self.finished_messages_queue, self.closed_connections_queue,
                                             self.log_filename, self.log_file_lock)
            new_connection.initiate_connection()
            self.connections[address] = new_connection
        else:
            utility.log_message(
                "The connection already exists, please close it first!",
                self.log_filename, self.log_file_lock)
def send_keep_alive_loop(self):
    """Periodically send a keep alive to every neighbor that still has retries left.

    Each round waits ``SEND_KEEP_ALIVE_INTERVAL`` on the stopper event and then
    waits until ``self.continue_keep_alives`` is set. The loop ends once the
    stopper event is set.
    """
    while not self.stopper.wait(SEND_KEEP_ALIVE_INTERVAL):
        self.continue_keep_alives.wait()
        with self.neighbors_lock:
            for neighbor_key, (mask, cost, current_retries, _) in self.neighbors.items():
                if current_retries <= 0:
                    # Neighbors without retries left are skipped
                    continue
                ip, port = neighbor_key
                utility.log_message(f"Sending keep alive to {ip}:{port}...", self)
                # Create a timer to implement the timeout, will execute code to handle the timeout after it
                # triggers
                # If an ack is received this timer will be cancelled
                timeout_timer = threading.Timer(KEEP_ALIVE_TIMEOUT, self.handle_keep_alive_timeout,
                                                kwargs={"ip": ip, "port": port})
                timeout_timer.start()
                # Save the timer in that neighbor's tuple so it can be retrieved and cancelled if/when necessary
                self.neighbors[neighbor_key] = (mask, cost, current_retries, timeout_timer)
                self.send_keep_alive(ip, port)
    utility.log_message("Finished the send keep alive loop!", self)
def close_connections_loop(self):
    """
    Whenever a connection is going to terminate, it should signal this thread to remove its entry in the
    connections table.

    Runs until ``self.terminate_node_event`` is set. The value ``0x5`` in the queue only wakes the loop up.
    """
    while not self.terminate_node_event.is_set():
        # Block until a socket is terminated, and delete that entry
        try:
            terminated_address = self.closed_connections_queue.get(block=True)
        except queue.Empty:
            continue
        if terminated_address == 0x5:
            # This message is placed in the queue to wake it up, ignore it
            continue
        # The context manager releases connections_lock on every path
        with self.connections_lock:
            # pop() instead of del: an address that is not in the table must not kill this thread
            if self.connections.pop(terminated_address, None) is None:
                utility.log_message(
                    f"Tried closing {terminated_address}, but that connection didn't exist!",
                    self.log_filename, self.log_file_lock)
                continue
            utility.log_message(f"Closed {terminated_address} successfully!", self.log_filename,
                                self.log_file_lock)
def handle_keep_alive_timeout(self, **kwargs):
    """Handle a keep alive that was not acknowledged in time.

    Decreases the neighbor's remaining retries. When the last retry is used up
    the neighbor is removed from the reachability table and a flood is started.
    A neighbor that already has 0 retries is left untouched.

    :param kwargs: must contain the keys "ip" and "port" of the neighbor.
    """
    # Get the parameters from the kwargs dictionary
    ip = kwargs["ip"]
    port = kwargs["port"]
    with self.neighbors_lock:
        # Check the neighbor's retry status
        neighbor = self.neighbors[ip, port]
        retries_left = neighbor[2] - 1
        if neighbor[2] == 1:
            # If decreasing the remaining retries would set it to 0, remove the entry and start a flood
            utility.log_message(f"Keep alive message to {ip}:{port} timed out! No more retries remaining, deleting "
                                f"entry and starting flood...", self)
            self.neighbors[ip, port] = (neighbor[0], neighbor[1], retries_left, None)
            self.remove_reachability_table_entry(ip, port)
            self.send_flood_message(HOP_NUMBER)
        elif neighbor[2] > 0:
            # If the neighbor is not already at 0 retries, decrease the remaining retries
            self.neighbors[ip, port] = (neighbor[0], neighbor[1], retries_left, None)
            # Report the count after the decrement, the old value would overstate it by one
            utility.log_message(f"Keep alive message to {ip}:{port} timed out! {retries_left} retries remaining...",
                                self)
def close_connection(self, address):
    """
    Closes one connection in the connections table.

    Only a log entry is written when no connection exists for the address.

    :param address: the ip-port pair of the connection to be closed and deleted.
    """
    address = utility.resolve_localhost(address)
    utility.log_message(
        f"Closing {address}, will wait until its operations are complete...",
        self.log_filename, self.log_file_lock)
    # The context manager releases connections_lock on every exit path; the previous early return
    # left the lock held forever when the connection did not exist
    with self.connections_lock:
        # Check if the connection exists
        if address not in self.connections:
            utility.log_message(
                f"Tried closing {address}, but that connection didn't exist!",
                self.log_filename, self.log_file_lock)
            return
        # Close the connection
        self.connections[address].close()
def __init__(self, address):
    """Bind a UDP socket to the address and start the reader and connection-removal threads.

    :param address: the ip-port pair to bind to, resolved with ``utility.resolve_localhost``.
    """
    address = utility.resolve_localhost(address)

    # Socket
    udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    udp_socket.bind(address)
    self.sock = udp_socket

    # Logging
    host, port = address[0], address[1]
    self.log_filename = f"node_{host}_{port}.txt"
    self.log_file_lock = threading.Lock()

    # Connections: a dict of pseudoTCP sockets indexed by ip-port pairs
    self.connections = {}
    self.connections_lock = threading.Lock()

    # This class only has one state: accepting connection or not. All other states are handled by the sockets
    self.accepting_connections = False
    # Reentrant to support close_node()'s logic
    self.accepting_connections_lock = threading.RLock()

    # Queues
    self.finished_messages_queue = queue.Queue()
    self.closed_connections_queue = queue.Queue()

    # Socket locks
    self.sock_read_lock = threading.Lock()
    self.sock_send_lock = threading.Lock()

    # Events
    self.terminate_node_event = threading.Event()

    # Threads, started only after the state above is in place
    self.message_reader = threading.Thread(target=self.read_loop)
    self.message_reader.start()
    self.connection_killer = threading.Thread(target=self.close_connections_loop)
    self.connection_killer.start()

    utility.log_message(f"Node has been started in {address}!", self.log_filename, self.log_file_lock)
def send_flood_message(self, hops):
    """Send a flood packet to every neighbor and reset the local routing state.

    After sending, incoming updates are ignored, keep alives are halted, and
    the reachability table and the message queue are emptied. A timer calls
    ``reset_ignore_updates`` after ``IGNORE_AFTER_FLOOD_INTERVAL``.

    :param hops: hop count written in the second byte of the packet.
    """
    # One byte of packet type followed by one byte with the hop count
    message = bytearray(struct.pack("!BB", PKT_TYPE_FLOOD, hops))
    for ip, port in self.neighbors:
        utility.log_message(f"Sending flood message of {len(message)} bytes to {ip}:{port}", self)
        self.send_message(ip, port, message)

    # Set the event to indicate that updates should be ignored
    self.ignore_updates.set()
    # Halt keep alives
    self.continue_keep_alives.clear()

    # Clear the reachability table and message queue
    with self.reachability_table_lock:
        self.reachability_table.clear()
    with self.message_queue_lock:
        self.message_queue = queue.Queue()

    # Start a timer to clear the previous event so updates can continue
    resume_timer = threading.Timer(IGNORE_AFTER_FLOOD_INTERVAL, self.reset_ignore_updates)
    resume_timer.start()
def decode_tuples(self, message, origin_node):
    """Decode the tuples of a table update and apply each one to the reachability table.

    Each tuple holds 4 bytes of ip, 1 byte of mask, 2 bytes of port and 3 bytes
    of cost. Updates from a node that is not in ``self.neighbors`` are discarded.
    Trailing bytes that do not form a complete tuple are ignored.

    :param message: the update payload, without the packet type and tuple count.
    :param origin_node: the ip-port pair of the node that sent the update.
    """
    # Ignore updates that do not originate from a neighbor
    if origin_node not in self.neighbors:
        utility.log_message(f"Discarding update from {origin_node[0]}:{origin_node[1]} as it is not a neighbor.", self)
        return
    # Only whole tuples are decoded, a truncated packet must not raise struct.error in the caller's thread
    trailing_bytes = len(message) % TUPLE_SIZE
    if trailing_bytes:
        utility.log_message(f"Ignoring {trailing_bytes} trailing bytes in update from "
                            f"{origin_node[0]}:{origin_node[1]}.", self)
    for offset in range(0, len(message) - trailing_bytes, TUPLE_SIZE):
        # Unpack the binary
        tuple_bytes = struct.unpack('!BBBBBBBBBB', message[offset:offset + TUPLE_SIZE])
        # Get each of the tuple's values
        ip_bytes = tuple_bytes[:4]
        ip = f"{ip_bytes[0]}.{ip_bytes[1]}.{ip_bytes[2]}.{ip_bytes[3]}"
        mask = tuple_bytes[4]
        port = int.from_bytes(tuple_bytes[5:7], byteorder='big', signed=False)
        cost = int.from_bytes(tuple_bytes[7:], byteorder='big', signed=False)
        self.update_reachability_table(ip, port, mask, cost, origin_node)
def find_awake_neighbors(self):
    """Send keep alives to all neighbors and mark the ones that never answer as dead.

    Keep alives from ``send_keep_alive_loop`` are halted while this runs. Up to
    ``KEEP_ALIVE_RETRIES`` rounds are sent, sleeping ``KEEP_ALIVE_TIMEOUT``
    after each one. Neighbors still in ``self.unawakened_neighbors`` at the end
    get their retries set to 0.
    """
    # Halt infinite keep alives
    self.continue_keep_alives.clear()

    # Assume all neighbors are dead
    with self.unawakened_neighbors_lock:
        self.unawakened_neighbors = list(self.neighbors.keys())

    # Try to find neighbors until they are all alive or the maximum amount of retries is met
    for _ in range(KEEP_ALIVE_RETRIES):
        if not self.unawakened_neighbors:
            break
        with self.unawakened_neighbors_lock:
            for (ip, port) in self.unawakened_neighbors:
                utility.log_message(f"Waking {ip}:{port}", self)
                self.send_keep_alive(ip, port)
        # Sleep for the timeout duration before trying again
        time.sleep(KEEP_ALIVE_TIMEOUT)

    with self.unawakened_neighbors_lock:
        if self.unawakened_neighbors:
            silent_labels = []
            for (ip, port) in self.unawakened_neighbors:
                # Set nodes as dead
                with self.neighbors_lock:
                    neighbor = self.neighbors[ip, port]
                    self.neighbors[ip, port] = (neighbor[0], neighbor[1], 0, None)
                # Collect the address to inform the user
                silent_labels.append(f"{ip}:{port} ")
            unawakened_str = "Unawoken neighbors: " + "".join(silent_labels)
            utility.log_message_force(unawakened_str, self)
        else:
            utility.log_message("All neighbors have awakened!", self)

    # Continue infinite keep alives
    self.continue_keep_alives.set()
def receive_message(self):
    """Take one message from the message queue and handle it according to its type.

    Returns without doing anything when no message arrives within
    ``SOCKET_TIMEOUT``. Handles table updates, keep alives, keep alive acks
    and floods; other message types are ignored. Truncated packets and acks
    from nodes that are not neighbors are discarded with a log entry.
    """
    try:
        message, address = self.message_queue.get(block=True, timeout=SOCKET_TIMEOUT)
    except queue.Empty:
        return
    message_type = int.from_bytes(message[0:PKT_TYPE_SIZE], byteorder='big', signed=False)
    if message_type == PKT_TYPE_UPDATE:
        # A packet too short to hold the tuple count would make struct.unpack raise
        if len(message) < PKT_TYPE_SIZE + TUPLE_COUNT_SIZE:
            utility.log_message(f"Discarding truncated table update from {address[0]}:{address[1]}.", self)
            return
        tuple_count = struct.unpack('!H', message[PKT_TYPE_SIZE:PKT_TYPE_SIZE + 2])[0]
        utility.log_message(f"Received a table update from {address[0]}:{address[1]} of size "
                            f"{len(message)} with {tuple_count} tuples.", self)
        # Decode the received tuples and update the reachability table if necessary
        self.decode_tuples(message[PKT_TYPE_SIZE + TUPLE_COUNT_SIZE:], address)
    elif message_type == PKT_TYPE_KEEP_ALIVE:
        utility.log_message(f"Received a keep alive from {address[0]}:{address[1]}.", self)
        self.send_ack_keep_alive(address[0], address[1])
    elif message_type == PKT_TYPE_ACK_KEEP_ALIVE:
        utility.log_message(f"Received a keep alive ack from {address[0]}:{address[1]}.", self)
        # Check if this is the first time the node has replied; the check and the removal happen under
        # the same lock so the list cannot be replaced in between
        with self.unawakened_neighbors_lock:
            if address in self.unawakened_neighbors:
                self.unawakened_neighbors.remove(address)
        with self.neighbors_lock:
            neighbor = self.neighbors.get(address)
            if neighbor is None:
                # Any host can send this packet, an unknown sender must not raise KeyError
                utility.log_message(f"Discarding keep alive ack from {address[0]}:{address[1]} as it is not a "
                                    f"neighbor.", self)
                return
            # Cancel the timer
            try:
                neighbor[3].cancel()
            except AttributeError:
                # No timer is stored for this neighbor
                pass
            # If the node was thought dead re-add it to the reachability table
            with self.reachability_table_lock:
                self.reachability_table[address] = (neighbor[0], address, neighbor[1])
            # Reset the retry number
            self.neighbors[address] = (neighbor[0], neighbor[1], KEEP_ALIVE_RETRIES, None)
    elif message_type == PKT_TYPE_FLOOD:
        # A flood packet without the hop byte would make struct.unpack raise
        if len(message) < 2:
            utility.log_message(f"Discarding truncated flood from {address[0]}:{address[1]}.", self)
            return
        hops = struct.unpack("!B", message[1:2])[0]
        utility.log_message_force(f"Received a FLOOD with {hops} hops remaining from {address[0]}:{address[1]}."
                                  f"\nFlushing reachability table..."
                                  f"\nWill ignore updates for {IGNORE_AFTER_FLOOD_INTERVAL} seconds.", self)
        # Continue the flood with one less hop
        # NOTE(review): the table is only flushed inside send_flood_message, so a flood that arrives with
        # 0 hops is logged as flushing but changes nothing here. Confirm that this is intended.
        if hops < 0:
            utility.log_message_force(f"Received a flood with too few hops from {address}!", self)
        elif hops > 255:
            utility.log_message_force(f"Received a flood with too many hops from {address}!", self)
        elif hops != 0:
            self.send_flood_message(hops - 1)
def main_loop(self):
    """This loop will handle the three main events: receiving data from an upper layer, receiving an ack, or
    a timeout.

    Runs until ``self.terminate_socket_event`` is set, delegating every packet
    and every timeout to the current status object.
    """
    while not self.terminate_socket_event.is_set():
        try:
            # This call blocks until an element is available
            packet = self.receive_queue.get(block=True, timeout=utility.TIMEOUT)
        except queue.Empty:
            # Timed out waiting for a packet
            utility.log_message(
                f"Timeout! Handling with current status {self.get_current_status().STATUS_NAME}",
                self.log_filename, self.log_file_lock)
            self.get_current_status().handle_timeout(self)
            continue
        else:
            # The item is marked as done as soon as it has been taken out of the queue
            self.receive_queue.task_done()

        utility.log_message(
            f"Handling received packet with current status {self.get_current_status().STATUS_NAME}",
            self.log_filename, self.log_file_lock)
        self.get_current_status().handle_packet(packet=packet, node=self)
    utility.log_message("Main loop finished!", self.log_filename, self.log_file_lock)
def send_node_death_message(self, ip, port):
    """Send a one-byte node death packet to ip:port.

    :param ip: ip address of the receiving node.
    :param port: port of the receiving node.
    """
    # The packet consists only of the packet type byte
    message = bytearray(struct.pack("!B", PKT_TYPE_DEAD))
    utility.log_message(f"Sending node death message of {len(message)} bytes to {ip}:{port}", self)
    self.send_message(ip, port, message)
def send_keep_alive(self, ip, port):
    """Send a one-byte keep alive packet to ip:port.

    :param ip: ip address of the receiving node.
    :param port: port of the receiving node.
    """
    # The packet consists only of the packet type byte
    message = bytearray(struct.pack("!B", PKT_TYPE_KEEP_ALIVE))
    utility.log_message(f"Sending keep alive message of {len(message)} bytes to {ip}:{port}", self)
    self.send_message(ip, port, message)
def __init__(self, ip, mask, port, neighbors):
    """Bind the node's UDP socket, build its tables and threads, and log the welcome text.

    The threads are only created here, none of them is started.

    :param ip: ip address the socket is bound to.
    :param mask: mask of this node, only stored and logged here.
    :param port: port the socket is bound to.
    :param neighbors: dict mapping (ip, mask, port) of each neighbor to the cost of reaching it.
    """
    # Simple data
    self.ip = ip
    self.mask = mask
    self.port = port
    self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    self.sock.bind((self.ip, self.port))
    self.sock.setblocking(True)
    self.sock.settimeout(SOCKET_TIMEOUT)

    # Turns off many frequent prints, will not avoid logging
    self.print_updates = True

    # Structures
    # Reachability table: ip, port : mask, (ip, port), cost
    self.reachability_table = {}
    # Neighbors: ip, port : mask, cost, current_retries (0 if node is dead), Timer obj
    self.neighbors = {
        (n_ip, n_port): (n_mask, n_cost, 0, None)
        for (n_ip, n_mask, n_port), n_cost in neighbors.items()
    }
    # Queue to hold incoming messages
    # Will be flushed when encountering a flood
    self.message_queue = queue.Queue()
    # Used when waking nodes
    self.unawakened_neighbors = list(self.neighbors)

    # Locks
    self.reachability_table_lock = threading.Lock()
    self.neighbors_lock = threading.Lock()
    self.message_queue_lock = threading.Lock()
    self.unawakened_neighbors_lock = threading.Lock()

    # Events
    self.stopper = threading.Event()
    self.ignore_updates = threading.Event()
    self.continue_keep_alives = threading.Event()

    # Threads
    self.connection_handler_thread = threading.Thread(target=self.handle_incoming_connections_loop)
    self.message_reader_thread = threading.Thread(target=self.read_messages_loop)
    self.keep_alive_handler_thread = threading.Thread(target=self.send_keep_alive_loop)
    self.update_handler_thread = threading.Thread(target=self.send_updates_loop)
    self.command_handler_thread = threading.Thread(target=self.handle_console_commands)

    # Prints identifying the node
    utility.log_message(f"Welcome to node {ip}:{port}/{mask}!", self)
    utility.log_message(f"\nThis node's neighbors:", self)
    self.print_neighbors_table()
    # One log entry per line of the command help
    command_help = (
        "\nAvailable commands are:",
        " sendMessage <ip> <port> <message>",
        " exit",
        " change cost <neighbor ip> <neighbor port> <new cost>",
        " printOwn",
        " printTable",
        " printNeighbors",
        " prints <on|off>\n",
    )
    for help_line in command_help:
        utility.log_message(help_line, self)
def remove_reachability_table_entry(self, ip, port):
    """Delete the entry for ip:port from the reachability table, if there is one.

    :param ip: ip address of the entry to delete.
    :param port: port of the entry to delete.
    """
    entry_key = (ip, port)
    with self.reachability_table_lock:
        if entry_key not in self.reachability_table:
            # Nothing to delete and nothing to log
            return
        del self.reachability_table[entry_key]
        utility.log_message(f"DISCONNECT: Deleted {ip}:{port} from the reachability table.", self)