def send_result_to_neighbors(self, receive_id, client_params1):
    """Send ``client_params1`` to ``receive_id`` as a MSG_TYPE_SEND_MSG_TO_NEIGHBOR message.

    The payload travels under ``MyMessage.MSG_ARG_KEY_PARAMS_1``.
    """
    note = "send_result_to_neighbors. receive_id = " + str(receive_id)
    logging.info(note)
    outgoing = Message(MyMessage.MSG_TYPE_SEND_MSG_TO_NEIGHBOR, self.get_sender_id(), receive_id)
    outgoing.add_params(MyMessage.MSG_ARG_KEY_PARAMS_1, client_params1)
    self.send_message(outgoing)
def send_message_sync_model_to_client(self, receive_id, global_logits):
    """Push ``global_logits`` to ``receive_id`` in a MSG_TYPE_S2C_SYNC_TO_CLIENT message.

    The log line is emitted only after the message has been handed to
    ``self.send_message``.
    """
    sync_msg = Message(MyMessage.MSG_TYPE_S2C_SYNC_TO_CLIENT, self.get_sender_id(), receive_id)
    sync_msg.add_params(MyMessage.MSG_ARG_KEY_GLOBAL_LOGITS, global_logits)
    self.send_message(sync_msg)
    logging.info("send_message_sync_model_to_client. Receive_id: " + str(receive_id))
def send_model_to_server(self, receive_id, weights, local_sample_num, train_evaluation_metrics, test_evaluation_metrics):
    """Send a MSG_TYPE_C2S_SEND_MODEL_TO_SERVER message to ``receive_id``.

    The message carries the local ``weights``, the local sample count and the
    train/test evaluation metrics, each under its own MyMessage key.
    """
    outgoing = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id)
    # (key, value) pairs in the exact order they are attached.
    fields = (
        (MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights),
        (MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num),
        (MyMessage.MSG_ARG_KEY_TRAIN_EVALUATION_METRICS, train_evaluation_metrics),
        (MyMessage.MSG_ARG_KEY_TEST_EVALUATION_METRICS, test_evaluation_metrics),
    )
    for key, value in fields:
        outgoing.add_params(key, value)
    self.send_message(outgoing)
def _notify(self, msg):
    """Decode ``msg`` (JSON text, after ``str()``) and dispatch it to every observer.

    Each observer in ``self._observers`` receives ``(msg_type, decoded_message)``.
    """
    decoded = Message()
    decoded.init_from_json_string(str(msg))
    kind = decoded.get_type()
    for obs in self._observers:
        obs.receive_message(kind, decoded)
def send_model_to_server(self, receive_id, weights, local_sample_num, num_bits, latent_weight):
    """Send a MSG_TYPE_C2S_SEND_MODEL_TO_SERVER message to ``receive_id``.

    The message carries ``weights``, ``local_sample_num``, ``num_bits`` and
    ``latent_weight``, each under its corresponding MyMessage key.
    """
    outgoing = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id)
    fields = (
        (MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights),
        (MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num),
        (MyMessage.MSG_ARG_KEY_NUM_BITS, num_bits),
        (MyMessage.MSG_ARG_KEY_LATENT_WEIGHT, latent_weight),
    )
    for key, value in fields:
        outgoing.add_params(key, value)
    self.send_message(outgoing)
def send_message(self, message):
    """Copy ``message`` into a fresh Message and forward it via ``self.com_manager``.

    The type, sender and receiver are written first under the
    ``Message.MSG_ARG_KEY_*`` header keys. Every entry of
    ``message.get_params()`` is then copied over in iteration order.
    """
    envelope = Message()
    header = (
        (Message.MSG_ARG_KEY_TYPE, message.get_type()),
        (Message.MSG_ARG_KEY_SENDER, message.get_sender_id()),
        (Message.MSG_ARG_KEY_RECEIVER, message.get_receiver_id()),
    )
    for key, value in header:
        envelope.add(key, value)
    for key, value in message.get_params().items():
        envelope.add(key, value)
    self.com_manager.send_message(envelope)
def run(self):
    """Receive loop: pull raw messages from ``self.comm`` and enqueue them on ``self.q``.

    Each received payload is decoded with ``Message.init`` before it is queued.
    The loop never exits on its own. A failure while receiving or decoding one
    message is logged and the loop carries on with the next one.
    """
    logging.debug("Starting Thread:" + self.name + ". Process ID = " + str(self.rank))
    while True:
        try:
            msg_str = self.comm.recv()
            msg = Message()
            msg.init(msg_str)
            self.q.put(msg)
        except Exception:
            # Best-effort loop: keep receiving. Report through `logging` (with the
            # traceback) instead of traceback.print_exc(), so the error respects the
            # configured log handlers/levels like every other message here.
            logging.exception("Error in receive loop (rank %s); continuing", self.rank)
def send_model_to_server(self, receive_id, extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test):
    """Send extracted features, logits and labels to ``receive_id``.

    Despite the method name, the message type is
    MSG_TYPE_C2S_SEND_FEATURE_AND_LOGITS. Train features, logits and labels
    are attached first, followed by the test features and test labels.
    """
    outgoing = Message(MyMessage.MSG_TYPE_C2S_SEND_FEATURE_AND_LOGITS, self.get_sender_id(), receive_id)
    fields = (
        (MyMessage.MSG_ARG_KEY_FEATURE, extracted_feature_dict),
        (MyMessage.MSG_ARG_KEY_LOGITS, logits_dict),
        (MyMessage.MSG_ARG_KEY_LABELS, labels_dict),
        (MyMessage.MSG_ARG_KEY_FEATURE_TEST, extracted_feature_dict_test),
        (MyMessage.MSG_ARG_KEY_LABELS_TEST, labels_dict_test),
    )
    for key, value in fields:
        outgoing.add_params(key, value)
    self.send_message(outgoing)
def send_model_to_server(self, receive_id, host_train_logits, host_test_logits):
    """Send the host's train and test logits to ``receive_id`` (MSG_TYPE_C2S_LOGITS)."""
    outgoing = Message(MyMessage.MSG_TYPE_C2S_LOGITS, self.get_sender_id(), receive_id)
    for key, value in (
        (MyMessage.MSG_ARG_KEY_TRAIN_LOGITS, host_train_logits),
        (MyMessage.MSG_ARG_KEY_TEST_LOGITS, host_test_logits),
    ):
        outgoing.add_params(key, value)
    self.send_message(outgoing)
def send_message(self, msg: Message):
    """Publish ``msg`` as JSON over MQTT, choosing the topic by role.

    ``self._topic`` is the shared prefix; ``client_id == 0`` marks the server.

    [server] publishes to   <prefix>0_<receiverID>   (subscribes to <prefix><clientID>)
    [client] publishes to   <prefix><clientID>       (subscribes to <prefix>0_<clientID>)
    """
    if self.client_id != 0:
        # Client side: publish on the topic named after our own id.
        self._client.publish(self._topic + str(self.client_id), payload=msg.to_json())
        return

    # Server side: address the specific receiver.
    target = msg.get_receiver_id()
    server_topic = self._topic + "0_" + str(target)
    logging.info("topic = %s" % str(server_topic))
    body = msg.to_json()
    self._client.publish(server_topic, payload=body)
    logging.info("sent")
def send_message_init_config(self, receive_id, global_model_params, client_index):
    """Send MSG_TYPE_S2C_INIT_CONFIG to ``receive_id``.

    The message carries the global model params and the client index; the
    index is sent as a string.
    """
    init_msg = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id)
    for key, value in (
        (MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params),
        (MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)),
    ):
        init_msg.add_params(key, value)
    self.send_message(init_msg)
def __send_initial_config_to_client(self, process_id, global_model_params, global_arch_params):
    """Send MSG_TYPE_S2C_INIT_CONFIG with model and architecture params to ``process_id``.

    The receiver is logged just before the message is sent.
    """
    init_msg = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), process_id)
    for key, value in (
        (MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params),
        (MyMessage.MSG_ARG_KEY_ARCH_PARAMS, global_arch_params),
    ):
        init_msg.add_params(key, value)
    logging.info("MSG_TYPE_S2C_INIT_CONFIG. receiver: " + str(process_id))
    self.send_message(init_msg)
def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index):
    """Send MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT to ``receive_id``.

    The message carries the global model params and the client index; the
    index is sent as a string.
    """
    # Lazy logging args with %s. The old eager `"... %d" % receive_id` raised
    # TypeError for a non-int id (e.g. a str id), which aborted the send before
    # the message was even built. For int ids the logged text is unchanged.
    logging.info("send_message_sync_model_to_client. receive_id = %s", receive_id)
    message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id)
    message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
    message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index))
    self.send_message(message)
def send_message_init_config(self, receive_id, global_model_params, client_index):
    """Log, then send MSG_TYPE_S2C_INIT_CONFIG to ``receive_id``.

    The message carries the global model params and the client index; the
    index is sent as a string.
    """
    announcement = 'Initial Configurations sent to client {0}'.format(client_index)
    logging.info(announcement)
    init_msg = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id)
    for key, value in (
        (MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params),
        (MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)),
    ):
        init_msg.add_params(key, value)
    self.send_message(init_msg)
def __send_model_to_client_message(self, process_id, global_model_params, global_arch_params):
    """Send MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT with model and architecture params to ``process_id``."""
    # NOTE(review): the sender id is hard-coded to 0 rather than
    # self.get_sender_id(); presumably the server is always rank 0 — confirm.
    sync_msg = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, 0, process_id)
    for key, value in (
        (MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params),
        (MyMessage.MSG_ARG_KEY_ARCH_PARAMS, global_arch_params),
    ):
        sync_msg.add_params(key, value)
    logging.info(
        "__send_model_to_client_message. MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT. receiver: " + str(process_id))
    self.send_message(sync_msg)
def send_semaphore_to_client(self, receive_id):
    """Send a parameter-less MSG_TYPE_C2C_SEMAPHORE message to ``receive_id``."""
    self.send_message(Message(MyMessage.MSG_TYPE_C2C_SEMAPHORE, self.get_sender_id(), receive_id))
def send_message_to_client(self, receive_id, global_result):
    """Send ``global_result`` to ``receive_id`` as a MSG_TYPE_S2C_INFORMATION message."""
    info_msg = Message(MyMessage.MSG_TYPE_S2C_INFORMATION, self.get_sender_id(), receive_id)
    info_msg.add_params(MyMessage.MSG_ARG_KEY_INFORMATION, global_result)
    self.send_message(info_msg)
def send_message_init_config(self, receive_id):
    """Send a parameter-less MSG_TYPE_S2C_INIT_CONFIG message to ``receive_id``."""
    self.send_message(Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id))
def send_model_to_server(self, receive_id, client_gradient):
    """Send ``client_gradient`` to ``receive_id`` as a MSG_TYPE_C2S_INFORMATION message."""
    info_msg = Message(MyMessage.MSG_TYPE_C2S_INFORMATION, self.get_sender_id(), receive_id)
    info_msg.add_params(MyMessage.MSG_ARG_KEY_INFORMATION, client_gradient)
    self.send_message(info_msg)
def send_message_init_config(self, receive_id, global_model_params):
    """Send MSG_TYPE_S2C_INIT_CONFIG with ``global_model_params`` to ``receive_id``.

    The params are attached under ``MyMessage.MSG_ARG_KEY_MODEL_PARAMS``.
    A log line follows the send.
    """
    message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id)
    # Previously the argument was accepted but never attached, so the
    # init-config reached the receiver without any model params.
    message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
    self.send_message(message)
    logging.info("send_message_init_config. Receive_id: " + str(receive_id))
def send_grads_to_client(self, receive_id, grads):
    """Send ``grads`` to ``receive_id`` as a MSG_TYPE_S2C_GRADS message."""
    grads_msg = Message(MyMessage.MSG_TYPE_S2C_GRADS, self.get_sender_id(), receive_id)
    grads_msg.add_params(MyMessage.MSG_ARG_KEY_GRADS, grads)
    self.send_message(grads_msg)
def __send_msg_fedavg_send_model_to_server(self, weights, alphas, local_sample_num, valid_acc, valid_loss):
    """Send a MSG_TYPE_C2S_SEND_MODEL_TO_SERVER message from ``self.rank`` to rank 0.

    The message carries the sample count, the model weights, the architecture
    params (``alphas``) and the validation accuracy and loss. The latter two
    are sent under the LOCAL_TRAINING_ACC/LOSS keys.
    """
    upload = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.rank, 0)
    fields = (
        (MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num),
        (MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights),
        (MyMessage.MSG_ARG_KEY_ARCH_PARAMS, alphas),
        (MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_ACC, valid_acc),
        (MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_LOSS, valid_loss),
    )
    for key, value in fields:
        upload.add_params(key, value)
    self.send_message(upload)
def send_finish_to_server(self, receive_id):
    """Notify ``receive_id`` that this side has finished (MSG_TYPE_C2S_PROTOCOL_FINISHED)."""
    message = Message(MyMessage.MSG_TYPE_C2S_PROTOCOL_FINISHED, self.get_sender_id(), receive_id)
    # The message used to be built and then discarded; it must actually be sent.
    self.send_message(message)
def send_activations_and_labels_to_server(self, acts, labels, receive_id):
    """Send ``(acts, labels)`` as one tuple under MSG_ARG_KEY_ACTS in a MSG_TYPE_C2S_SEND_ACTS message.

    Note that ``receive_id`` is the last parameter here.
    """
    acts_msg = Message(MyMessage.MSG_TYPE_C2S_SEND_ACTS, self.get_sender_id(), receive_id)
    bundle = (acts, labels)
    acts_msg.add_params(MyMessage.MSG_ARG_KEY_ACTS, bundle)
    self.send_message(acts_msg)
def send_validation_over_to_server(self, receive_id):
    """Send a parameter-less MSG_TYPE_C2S_VALIDATION_OVER message to ``receive_id``."""
    self.send_message(Message(MyMessage.MSG_TYPE_C2S_VALIDATION_OVER, self.get_sender_id(), receive_id))
else: # client self._client.publish(self._topic + str(self.client_id), payload=msg.to_json()) def handle_receive_message(self): pass def stop_receive_message(self): pass if __name__ == '__main__': class Obs(Observer): def receive_message(self, msg_type, msg_params) -> None: print("receive_message(%s, %s)" % (msg_type, msg_params.to_string())) client = MqttCommManager("81.71.1.31", 1883) client.add_observer(Obs()) time.sleep(3) print('client ID:%s' % client.client_id) message = Message(0, 1, 2) message.add_params("key1", 1) client.send_message(message) time.sleep(10) print("client, send Fin...")