def update(self, aggregate_update):
    self.model.load_state_dict(aggregate_update)

    # Compute accuracy on the test set.
    acc = self.compute_accuracy(self.test_loader)
    self.test_acc.append(acc)

    if debug_level >= DEBUG_LEVEL.INFO:
        TERM.write('\tEpoch ' + str(len(self.test_acc) - 1) + '\n')
        TERM.write('\tClass Accuracies: {}'.format(
            100 * np.array(self.test_acc[-1])))

    # Occasionally save the current test accuracy.
    self.save_to_csv(acc, './train_curves/Server.csv')
def listen_for_clients(self):
    # TODO: make number bigger/variable
    self.listener_sock.listen(5)

    while self.listening:
        # Accept connection.
        client_sock, client_addr = self.listener_sock.accept()

        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write_success(
                'Established connection to {}'.format(client_addr))

        # Cache the client socket.
        self.connected_clients_by_addr[client_addr] = client_sock
        self.connected_clients_by_sock[client_sock] = client_addr
def __init__(self, host):
    if debug_level >= DEBUG_LEVEL.INFO:
        TERM.write_info('Started server at ' + str(host))

    # The socket to listen for connections on.
    self.listener_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.listener_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    self.listener_sock.bind(host)

    self.listener_process = None
    self.listening = False
    self.client_lock = threading.Lock()

    # The set of connected clients.
    self.connected_clients_by_sock = {}
    self.connected_clients_by_addr = {}
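
# NOTE: the training loop below calls self.start() and self.stop(), which are
# not included in this excerpt. A minimal sketch of how they might drive
# listen_for_clients() on a background thread, assuming the attributes
# initialized above (self.listening, self.listener_process) exist for exactly
# this purpose; the project's actual implementation may differ.
def start(self):
    # Begin accepting client connections on a background thread.
    self.listening = True
    self.listener_process = threading.Thread(
        target=self.listen_for_clients, daemon=True)
    self.listener_process.start()

def stop(self):
    # Stop accepting new connections. Note that a blocked accept() only
    # returns after the next connection attempt, so the loop may exit late.
    self.listening = False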
def connect(self, TIMEOUT):
    if debug_level >= DEBUG_LEVEL.INFO:
        TERM.write_info(
            'Attempting to connect to {} (Timeout: {}s)'.format(
                self.server, TIMEOUT))

    # Retry connecting to the server until the timeout expires.
    start = time.time()
    while (time.time() - start) < TIMEOUT:
        self.attempt_to_connect(TIMEOUT)
        if self.connected:
            break

    # Run the client (if the connection was successful).
    if not self.connected:
        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write_failure(
                'Time limit exceeded: {} not found'.format(self.server))
    else:
        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write_success('Successfully connected to {} as {}'.format(
                self.server, self.sock.getsockname()))
        self.run()
def train(self):
    # Optimization settings.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(self.model.parameters(),
                                lr=self.lr, momentum=self.momentum)

    start = time.time()
    for epoch in range(self.num_epochs):
        running_loss = 0.0
        for i, (inputs, targets) in enumerate(self.train_loader):
            # Move the batch to the GPU if CUDA is enabled.
            if self.use_cuda and torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()

            # Forward pass.
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate the loss.
            running_loss += loss.item()

        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write('\tEpoch ' + str(epoch + 1))

        train_acc_list, train_acc = self.evaluate_accuracy(self.train_loader)
        test_acc_list, test_acc = self.evaluate_accuracy(self.test_loader)

        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write('\tTraining Accuracy: {0:0.2f}'.format(train_acc))
            TERM.write('\tTesting Accuracy: {0:0.2f}'.format(test_acc))

        # Append this epoch's per-class test accuracies to the training curve.
        with open('./train_curves/Client{}.csv'.format(self.digits),
                  'a', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(test_acc_list)

    end = time.time()
    if debug_level >= DEBUG_LEVEL.INFO:
        TERM.write('\t%0.2f minutes' % ((end - start) / 60))
def run(self):
    while True:
        # Wait for weights from the FLServer.
        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write_info('Waiting for model from server...(Timeout: '
                            + str(self.TIMEOUT) + ')')

        start_time = time.time()
        weights = None
        while weights is None and (time.time() - start_time) < self.TIMEOUT:
            # Get the weights.
            weights = Communication_Handler.recv_msg(self.sock)

        if weights is None:
            if debug_level >= DEBUG_LEVEL.INFO:
                TERM.write_warning(
                    'Time Limit Exceeded: Weights not received.')
        else:
            if debug_level >= DEBUG_LEVEL.INFO:
                TERM.write_success('Weights received.')
                TERM.write_info('Training local model...')

            # Load the received weights into the local model.
            self.trainer.load_weights(weights)

            # Train the local model.
            self.trainer.train()

            if debug_level >= DEBUG_LEVEL.INFO:
                TERM.write_success('Training complete.')
                TERM.write_info('Sending update to server...')

            # Compute the focused update.
            update = self.trainer.focused_update()

            # Send the update to the server.
            Communication_Handler.send_msg(self.sock, update)

            if debug_level >= DEBUG_LEVEL.INFO:
                TERM.write_success('Update sent.')
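
# NOTE: Communication_Handler is used above but not defined in this excerpt.
# A minimal sketch, assuming a length-prefixed pickle protocol and a socket
# timeout so that recv_msg() returns None when nothing arrives in time; the
# project's actual implementation may differ.
import pickle
import struct

class Communication_Handler:

    @staticmethod
    def send_msg(sock, obj):
        # Serialize the object and prefix it with its 4-byte length.
        payload = pickle.dumps(obj)
        sock.sendall(struct.pack('>I', len(payload)) + payload)

    @staticmethod
    def recv_msg(sock, timeout=1.0):
        # Return the next deserialized message, or None on timeout.
        sock.settimeout(timeout)
        try:
            header = Communication_Handler._recv_all(sock, 4)
            length = struct.unpack('>I', header)[0]
            return pickle.loads(Communication_Handler._recv_all(sock, length))
        except socket.timeout:
            return None
        finally:
            sock.settimeout(None)

    @staticmethod
    def _recv_all(sock, n):
        # Read exactly n bytes from the socket.
        data = b''
        while len(data) < n:
            chunk = sock.recv(n - len(data))
            if not chunk:
                raise ConnectionError('Socket closed while receiving.')
            data += chunk
        return data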
def train(self):
    while len(self.connected_clients_by_addr) > 0:
        # Select a subset of the clients and broadcast the model.
        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write_info("Selecting clients...")
        self.select_clients(self.subset_size)

        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write_info("Broadcasting model...")
        self.broadcast_model()

        if debug_level >= DEBUG_LEVEL.INFO:
            TERM.write_info("Waiting for updates...(Timeout: "
                            + str(self.TIMEOUT) + ")")

        # Wait for each client's update, then aggregate.
        start = time.time()
        while (self.aggregated_update is None) and (time.time() - start) < self.TIMEOUT:
            self.wait_for_updates()
            self.attempt_to_aggregate_updates()

        # If an aggregated update has been created...
        if self.aggregated_update is not None:
            # Stop communication (temporarily).
            self.stop()

            if debug_level >= DEBUG_LEVEL.INFO:
                TERM.write_success("All updates received.")
                TERM.write_info("Aggregating updates...")

            # Update the model using the aggregated update.
            self.update_model(self.aggregated_update)
            self.aggregated_update = None

            # Reset the selected clients (to be re-selected next round).
            self.selected_clients_by_addr = {}
            self.selected_clients_updates = {}

            if debug_level >= DEBUG_LEVEL.INFO:
                TERM.write_info("Sending aggregated update to clients...")

            # Restart communication.
            self.start()
        else:
            if debug_level >= DEBUG_LEVEL.INFO:
                TERM.write_warning(
                    "Time-limit exceeded: Some updates were not received.")

    if debug_level >= DEBUG_LEVEL.INFO:
        TERM.write_failure("All clients disconnected.")
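
# NOTE: attempt_to_aggregate_updates() is called above but not shown in this
# excerpt. A minimal sketch of a FedAvg-style aggregation, assuming each entry
# of self.selected_clients_updates is a state_dict of tensors and that every
# selected client is weighted equally; the actual aggregation rule may differ.
def attempt_to_aggregate_updates(self):
    # Only aggregate once every selected client has reported its update.
    if len(self.selected_clients_updates) < len(self.selected_clients_by_addr):
        return

    updates = list(self.selected_clients_updates.values())
    averaged = {}
    for key in updates[0]:
        # Element-wise mean of each parameter tensor across clients.
        averaged[key] = torch.stack([u[key].float() for u in updates]).mean(dim=0)

    self.aggregated_update = averaged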
### Main Code ###

BUFFER_TIME = 5

if __name__ == '__main__':
    # The address for the server socket.
    server_hostname = socket.gethostbyname('localhost')
    server_port = 8080

    # Initialize the FL server.
    flServer = FLServer((server_hostname, server_port),
                        server_trainer.ServerTrainer())

    # Allow clients to connect.
    flServer.start()

    # Buffer time for clients to connect.
    TERM.write_info('Waiting for clients to connect...(Timeout: '
                    + str(BUFFER_TIME) + 's)')
    time.sleep(BUFFER_TIME)

    if len(flServer.connected_clients_by_addr) == 0:
        TERM.write_failure('Time limit exceeded: No clients connected.')
    else:
        TERM.write_warning('Time limit exceeded: '
                           + str(len(flServer.connected_clients_by_addr))
                           + ' client(s) connected.')

        # Train the FL server model.
        TERM.write_info('Starting FL training loop...')
        flServer.train()
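
# NOTE: the client-side entry point is not included in this excerpt. A minimal
# sketch of how a client might be launched against the server above, assuming
# the client class is named FLClient and takes the server address plus the
# digit classes its local dataset covers; both of those names are assumptions.
CLIENT_TIMEOUT = 60

if __name__ == '__main__':
    server_hostname = socket.gethostbyname('localhost')
    server_port = 8080

    # Create a client that trains on a subset of the digit classes, then
    # connect (and, on success, run the federated training loop).
    flClient = FLClient((server_hostname, server_port), digits=[1, 2, 3])
    flClient.connect(CLIENT_TIMEOUT)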
    # (Fragment of ClientTrainer.evaluate_accuracy: per-batch accuracy bookkeeping.)
    # Compute predictions.
    with torch.no_grad():
        outputs = self.model(inputs)
        preds = outputs.max(1, keepdim=True)[1]

    # Determine the number correct per class.
    labels, counts = torch.unique(
        targets[(preds.squeeze() == targets).nonzero()], return_counts=True)
    correct_by_class[labels] += counts.float()

    # Determine the total number per class.
    labels, counts = torch.unique(targets, return_counts=True)
    total_by_class[labels] += counts.float()

    # Avoid division by zero for classes with no examples.
    total_by_class[(total_by_class == 0).nonzero()] = 1.0  # TODO: Change this to be a NaN.

    acc = 100 * correct_by_class.sum() / total_by_class.sum()
    return (correct_by_class / total_by_class).cpu().tolist(), float(acc.cpu())


# TEST
if __name__ == '__main__':
    trainer = ClientTrainer([1, 2, 3], use_cuda=True)

    trainer.load_weights(trainer.model.state_dict())
    TERM.write('Weights loaded successfully!')

    trainer.train()
    TERM.write('Model trained successfully!')

    trainer.focused_update()
    TERM.write('Focused update computed successfully!')