def __generate_and_send_triplets(self, op_position: int, clients, shapes):
    # Decide which party holds the multiplier and which the multiplicand.
    if op_position == 2:
        shapes = [shapes[1], shapes[0]]
        clients = [clients[1], clients[0]]
    # Beaver triplet: z = (u0 + u1) @ (v0 + v1), additively shared as z0 + z1.
    u0 = np.random.uniform(-1, 1, shapes[0])
    u1 = np.random.uniform(-1, 1, shapes[0])
    v0 = np.random.uniform(-1, 1, shapes[1])
    v1 = np.random.uniform(-1, 1, shapes[1])
    z = np.matmul(u0 + u1, v0 + v1)
    z0 = z * np.random.uniform(0, 1, z.shape)
    z1 = z - z0
    try:
        self.send_check_msg(
            clients[0],
            PackedMessage(MessageType.Triplet_Array,
                          (clients[1], u0, v0, z0), clients[1]))
    except:
        self.logger.logE("Sending triplets to client %d failed" % clients[0])
    try:
        self.send_check_msg(
            clients[1],
            PackedMessage(MessageType.Triplet_Array,
                          (clients[0], v1, u1, z1), clients[0]))
    except:
        self.logger.logE("Sending triplets to client %d failed" % clients[1])
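
# A standalone sketch (not part of the project) checking the triplet invariant
# produced above: the additive shares z0, z1 must reconstruct the product of
# the reconstructed factors, i.e. z0 + z1 == (u0 + u1) @ (v0 + v1).
import numpy as np

rng = np.random.default_rng(0)
u0, u1 = rng.uniform(-1, 1, (4, 3)), rng.uniform(-1, 1, (4, 3))
v0, v1 = rng.uniform(-1, 1, (3, 2)), rng.uniform(-1, 1, (3, 2))
z = np.matmul(u0 + u1, v0 + v1)
z0 = z * rng.uniform(0, 1, z.shape)   # random additive split, as above
z1 = z - z0
assert np.allclose(z0 + z1, z)        # each party holds one valid share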
def GetComputationData(self, request, context):
    msg = decode_ComputationData(request)
    if self.msg_handler(msg, request.client_id):
        return encode_ComputationData(
            PackedMessage(header=MessageType.RECEIVED_OK, data=None))
    else:
        return encode_ComputationData(
            PackedMessage(header=MessageType.RECEIVED_ERR,
                          data="Buffer occupied, send it later"))
def multiply_BA_with(self, client_id: int, triplet_id: int,
                     client_shape: tuple, matB: np.ndarray):
    # Request a Beaver triplet from the triplet provider (op_position 2:
    # this party holds the right-hand factor B).
    try:
        self.send_check_msg(
            triplet_id,
            PackedMessage(MessageType.Triplet_Set,
                          (2, client_id, matB.shape, client_shape)))
        triplet_msg = self.receive_check_msg(
            triplet_id, MessageType.Triplet_Array, key=client_id).data
        shared_V_self, shared_U_self, shared_W_self = triplet_msg[1:]
    except:
        self.logger.logE(
            "Get triplet arrays failed. Stop multiplication with client %d."
            % client_id)
        return False
    # Additively share B and swap shares with the other party.
    try:
        shared_B_self = matB * np.random.uniform(0, 1, matB.shape)
        shared_B_other = matB - shared_B_self
        self.send_check_msg(
            client_id, PackedMessage(MessageType.MUL_Mat_Share, shared_B_other))
        shared_A_self = self.receive_check_msg(
            client_id, MessageType.MUL_Mat_Share).data
    except:
        self.logger.logE(
            "Swap mat shares with client failed. Stop multiplication with client %d."
            % client_id)
        return False
    # Exchange shares of A - U and B - V so both sides can open them.
    try:
        shared_A_sub_U_self = shared_A_self - shared_U_self
        self.send_check_msg(
            client_id,
            PackedMessage(MessageType.MUL_AsubU_Share, shared_A_sub_U_self))
        shared_A_sub_U_other = self.receive_check_msg(
            client_id, MessageType.MUL_AsubU_Share).data
        shared_B_sub_V_self = shared_B_self - shared_V_self
        self.send_check_msg(
            client_id,
            PackedMessage(MessageType.MUL_BsubV_Share, shared_B_sub_V_self))
        shared_B_sub_V_other = self.receive_check_msg(
            client_id, MessageType.MUL_BsubV_Share).data
    except:
        self.logger.logE(
            "Swap A-U and B-V with client failed. Stop multiplication with client %d"
            % client_id)
        return False
    A_sub_U = shared_A_sub_U_self + shared_A_sub_U_other
    B_sub_V = shared_B_sub_V_self + shared_B_sub_V_other
    # This party's additive share of A @ B; the open cross term
    # (A - U) @ (B - V) is contributed once by the AB-side partner.
    self.product = shared_U_self @ B_sub_V + A_sub_U @ shared_V_self + shared_W_self
    return True
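
# A minimal numpy sketch (assumed, not from the project) of the Beaver
# reconstruction identity the two product shares rely on: with W = U @ V,
#   A @ B = U @ (B - V) + (A - U) @ V + W + (A - U) @ (B - V),
# so the BA-side share computed above plus the AB-side share (which also
# adds the open (A - U) @ (B - V) term once) opens to the true product.
import numpy as np

rng = np.random.default_rng(1)
A, B = rng.normal(size=(5, 4)), rng.normal(size=(4, 3))
U, V = rng.normal(size=A.shape), rng.normal(size=B.shape)
W = U @ V
opened = U @ (B - V) + (A - U) @ V + W + (A - U) @ (B - V)
assert np.allclose(opened, A @ B)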
def compare_split_node(self):
    client_outs = self.broadcaster.receive_all(self.feature_client_ids,
                                               MessageType.XGBOOST_GAIN)
    if self.broadcaster.error:
        self.logger.logE("Gather clients' outputs failed. Stop training")
        return False
    # Send an ack message to the client whose split has the minimum loss.
    selected = None
    for data_client in self.feature_client_ids:
        if selected is None or client_outs[selected] > client_outs[data_client]:
            selected = data_client
    msgs = dict()
    for data_client in self.feature_client_ids:
        if data_client == selected:
            msgs[data_client] = PackedMessage(
                MessageType.XGBOOST_SELECTED_NODE, 'ack')
        else:
            msgs[data_client] = PackedMessage(
                MessageType.XGBOOST_SELECTED_NODE, 'rej')
    self.broadcaster.broadcast(self.feature_client_ids, msgs)
    if self.broadcaster.error:
        self.logger.logE("Broadcast split info failed. Stop training")
        return False
    if not self.get_residual(selected):
        return False
    preds = self.broadcaster.receive_all(self.feature_client_ids,
                                         MessageType.XGBOOST_PRED_LABEL)
    if self.broadcaster.error:
        self.logger.logE("Get predict for each epoch failed. Stop training")
        return False
    # The ensemble prediction is the sum of every client's partial score.
    pred = np.zeros((self.raw_label_data.shape[0]))
    for k, v in preds.items():
        pred += v
    pred = pred.reshape(-1, 1)
    try:
        metric = self.metric_func(self.raw_label_data, pred)
        msg = PackedMessage(MessageType.XGBOOST_PRED_LABEL,
                            (self.label_data.mean(), metric))
        self.send_check_msg(self.label_client_id, msg)
    except:
        self.logger.logE("Send predict label to label client failed. Stop training")
        return False
    return True
def send(self, receiver: int, msg: PackedMessage, time_out: float = None):
    """
    :param receiver: the receiving client
    :param msg: the message to send
    :param time_out: maximum time allowed for sending; give up once it is
        exceeded. Falls back to the default timeout when not set.
    :return: the receiver's response message
    """
    if not time_out:
        time_out = self.time_out
    resp = PackedMessage(MessageType.RECEIVED_ERR, None)
    start_send_time = time.time()
    # Keep retrying while the receiver reports that its buffer is occupied.
    while resp.header == MessageType.RECEIVED_ERR:
        if time.time() - start_send_time > time_out:
            self.logger.log(
                "Timeout while sending to client %d. Time elapsed %.3f" %
                (receiver, time_out))
            return resp
        try:
            resp = self.rpc_clients[receiver].sendComputationMessage(
                msg, self.client_id, time_out)
        except Exception as e:
            self.logger.logE("Error while sending to client %d" % receiver +
                             ":" + str(e))
            return resp
    return resp
def _broadcast_start(self, stop=False):
    if not stop:
        header = MessageType.XGBOOST_NEXT_TRAIN_ROUND
    else:
        header = MessageType.XGBOOST_TRAINING_STOP
    start_data = self.label_data
    self.broadcaster.broadcast(
        self.feature_client_ids + [self.label_client_id],
        PackedMessage(header, start_data))
def _before_training(self):
    try:
        self._build_mlp_model(self.in_dim, self.layers)
    except:
        self.logger.logE("Build tensorflow model failed. Stop training.")
        return False
    self.client_dims = self.broadcaster.receive_all(
        self.feature_client_ids, MessageType.SharedNN_ClientDim)
    if self.broadcaster.error:
        self.logger.logE("Gather clients' dims failed. Stop training.")
        return False
    train_config = {
        "client_dims": self.client_dims,
        "out_dim": self.in_dim,
        "batch_size": self.batch_size,
        "test_batch_size": self.test_batch_size,
        "learning_rate": self.learning_rate
    }
    self.broadcaster.broadcast(
        self.feature_client_ids + [self.label_client_id],
        PackedMessage(MessageType.SharedNN_TrainConfig, train_config))
    if self.broadcaster.error:
        self.logger.logE("Broadcast training config message failed. Stop training.")
        return False
    return True
def predict(self):
    self.logger.log("MainClient predict start...")
    # Get the test labels from the label client.
    try:
        self.send_check_msg(self.label_client_id,
                            PackedMessage(MessageType.XGBOOST_TRAIN, False))
        y_true = self.receive_check_msg(self.label_client_id,
                                        MessageType.XGBOOST_LABEL).data
    except:
        self.logger.logE("Get test label from label client failed. Stop predict.")
        return False
    y_preds = np.zeros((y_true.shape[0]))
    y_pred_dict = self.broadcaster.receive_all(self.feature_client_ids,
                                               MessageType.XGBOOST_PRED_LABEL)
    if self.broadcaster.error:
        self.logger.logE("Gather predict y failed")
        return False
    # The final prediction is the sum of every feature client's partial score.
    for y_pred in y_pred_dict.values():
        y_preds[:] += y_pred
    y_preds = y_preds.reshape(-1, 1)
    res = self.metric_func(y_true, y_preds)
    self.logger.log("Predict metric {}".format(res))
    return True
def _before_training(self):
    if not super(FeatureClient, self)._before_training():
        return False
    try:
        self.send_check_msg(
            self.main_client_id,
            PackedMessage(MessageType.SharedNN_ClientDim, self.data_dim))
        config = self.receive_check_msg(
            self.main_client_id, MessageType.SharedNN_TrainConfig).data
        self.logger.log(
            "Received main client's config message: {}".format(config))
        self.other_feature_dims = config["client_dims"]
        self.output_dim = config["out_dim"]
        self.batch_size = config["batch_size"]
        self.test_batch_size = config["test_batch_size"]
        self.learning_rate = config["learning_rate"]
    except:
        self.logger.logE("Get training config from server failed. Stop training.")
        return False
    try:
        # Hold a share of every other feature client's parameters.
        for client_id in self.other_feature_dims:
            other_dim = self.other_feature_dims[client_id]
            self.other_paras[client_id] = np.random.normal(
                0, 1 / (len(self.feature_client_ids) * other_dim),
                [other_dim, self.output_dim])
    except:
        self.logger.logE("Initialize weights failed. Stop training.")
        return False
    self.para = np.random.normal(
        0, 1 / (len(self.feature_client_ids) * self.data_dim),
        [self.data_dim, self.output_dim])
    return True
def send_gain_to(client_id: int, update: float):
    try:
        self.send_check_msg(
            client_id, PackedMessage(MessageType.XGBOOST_GAIN, update))
    except:
        self.logger.logE("Error encountered while sending gain to other client")
        self.error = True
def __send_aligned_ids_to(self, client_id):
    try:
        self.send_check_msg(
            client_id,
            PackedMessage(MessageType.ALIGN_FINAL_IDS, self.aligned_ids))
    except:
        self.logger.logE("Sending aligned ids to client %d failed" % client_id)
        self.error = True
def send_pred(server_id, update):
    try:
        self.send_check_msg(
            server_id, PackedMessage(MessageType.XGBOOST_PRED_LABEL, update))
    except:
        self.logger.logE("Error encountered while sending label to other client")
        self.error = True
def __send_aes_key_to(self, client_id):
    try:
        self.send_check_msg(
            client_id,
            PackedMessage(MessageType.ALIGN_AES_KEY,
                          (self.shared_aes_key, self.shared_aes_iv)))
    except:
        self.logger.logE("Send raw aes key to client %d failed." % client_id)
        self.error = True
def _get_raw_label_data(self):
    try:
        self.send_check_msg(self.label_client_id,
                            PackedMessage(MessageType.XGBOOST_TRAIN, True))
        self.label_data = self.receive_check_msg(
            self.label_client_id, MessageType.XGBOOST_LABEL).data
        self.raw_label_data = deepcopy(self.label_data)
    except:
        self.logger.logE("Receive label from label client failed. Stop training.")
        return False
    return True
def _backward(self):
    try:
        grad_on_output, status = self.receive_check_msg(
            self.main_client_id, MessageType.SharedNN_FeatureClientGrad).data
    except:
        self.logger.logE("Receive grads message from main client failed. Stop training.")
        return False
    if grad_on_output is not None:
        own_para_grad = self.batch_data.transpose() @ grad_on_output
        # Split the gradient into random portions that sum to 1: apply one
        # portion locally and send the rest to the clients holding shares
        # of this parameter.
        portion = np.random.uniform(0, 1, len(self.feature_client_ids))
        portion /= np.sum(portion)
        self.para -= self.learning_rate * own_para_grad * portion[0]
        other_grad_msgs = dict()
        current_portion = 1
        for other_id in self.other_feature_client_ids:
            other_grad_msgs[other_id] = PackedMessage(
                MessageType.SharedNN_FeatureClientParaGrad,
                own_para_grad * portion[current_portion])
            current_portion += 1
        try:
            self.broadcaster.broadcast(self.other_feature_client_ids,
                                       other_grad_msgs)
            self_para_grads = self.broadcaster.receive_all(
                self.other_feature_client_ids,
                MessageType.SharedNN_FeatureClientParaGrad)
            for other_id in self_para_grads:
                self.other_paras[other_id] -= \
                    self.learning_rate * self_para_grads[other_id]
        except:
            self.logger.logE("Swap parameter gradients with other clients failed. Stop training.")
            return False
    if status == "Stop":
        self.logger.log("Received Stop signal from main-client, stop training.")
        self.finished = True
    elif status == "Continue-Test":
        self.mpc_mode = MPCC.ClientMode.Test
    else:
        self.mpc_mode = MPCC.ClientMode.Train
    return True
def send_label_to_main(self):
    try:
        header = MessageType.XGBOOST_LABEL
        train = self.receive_check_msg(self.main_client_id,
                                       MessageType.XGBOOST_TRAIN).data
        if train:
            data = self.label
        else:
            data = self.test_label
        self.send_check_msg(self.main_client_id, PackedMessage(header, data))
    except:
        self.logger.logE("Send label to main client failed. Stop training.")
        return False
    return True
def _before_training(self):
    # Send the config dict to every data client.
    self.broadcaster.broadcast(
        self.feature_client_ids + [self.label_client_id],
        PackedMessage(MessageType.XGBOOST_TRAIN_CONFIG, self.config))
    if self.broadcaster.error:
        self.logger.logE("Broadcast training config message failed. Stop training.")
        return False
    # if not self._get_raw_label_data():
    #     return False
    # self.broadcaster.receive_all(self.feature_client_ids,
    #                              MessageType.CLIENT_READY)
    return True
def _compute_loss(self):
    try:
        preds, mode = self.receive_check_msg(
            self.main_client_id, MessageType.SharedNN_MainClientOut).data
        if mode == "Train" or mode == "Train-Stop":
            self.mpc_mode = MPCC.ClientMode.Train
            self.batch_data = self.train_data_loader.get_batch(self.batch_size)
        elif mode == "Test" or mode == "Test-Stop":
            self.mpc_mode = MPCC.ClientMode.Test
            self.logger.log("Received Test signal. Load test set data to batch.")
            self.batch_data = self.test_data_loader.get_batch(self.test_batch_size)
        if mode[-4:] == "Stop":
            self.logger.log("Received Stop signal. Stop training after finish this round.")
            self.finished = True
        loss = self.loss_func.forward(self.batch_data, preds)
        metric = self.metric_func(self.batch_data, preds)
        if self.mpc_mode is MPCC.ClientMode.Test:
            self.test_record.append(
                [time.time() - self.start_time, self.n_rounds, loss] + metric)
        self.logger.log("Current batch: {} loss: {}, metric value: {}".format(
            self.n_rounds, loss, metric))
        grad = self.loss_func.backward()
        self.send_check_msg(
            self.main_client_id,
            PackedMessage(MessageType.SharedNN_MainClientGradLoss, (grad, loss)))
    except:
        self.logger.logE("Compute gradient for main client predictions failed. Stop training.")
        return False
    return True
def _compute_loss(self):
    try:
        mode_str = "Train"
        if self.mpc_mode is MPCC.ClientMode.Test:
            mode_str = "Test"
        if self.n_rounds == self.max_iter:
            self.finished = True
            mode_str += "-Stop"
        self.send_check_msg(
            self.label_client_id,
            PackedMessage(MessageType.SharedNN_MainClientOut,
                          (self.network_out.numpy(), mode_str)))
        self.grad_on_output = self.receive_check_msg(
            self.label_client_id,
            MessageType.SharedNN_MainClientGradLoss).data[0]
    except:
        self.logger.logE("Get gradients from label client failed. Stop training.")
        return False
    return True
def _before_training(self):
    self.logger.log("Start sync random seed with other data-providers.")
    random_seed = np.random.randint(0, 999999999)
    other_data_client_ids = self.feature_client_ids + [self.label_client_id]
    other_data_client_ids.remove(self.client_id)
    self.broadcaster.broadcast(
        other_data_client_ids,
        PackedMessage(MessageType.SharedNN_RandomSeed, random_seed))
    all_seeds = self.broadcaster.receive_all(other_data_client_ids,
                                             MessageType.SharedNN_RandomSeed)
    if not self.broadcaster.error:
        # XOR all parties' seeds so every data-provider derives the same
        # seed and therefore draws identical batches.
        for c in all_seeds:
            random_seed ^= all_seeds[c]
        self.train_data_loader.set_random_seed(random_seed)
        self.test_data_loader.set_random_seed(random_seed)
        self.logger.log("Random seed swapped and set to data loaders.")
        return True
    else:
        self.logger.log("Swapping random seed with other data-providers failed. Stop training.")
        return False
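
# A tiny standalone sketch (assumed, not from the project) of why the XOR
# exchange above works: every party XORs its own seed with all received
# seeds, so each one ends up with the XOR of all seeds -- the same value.
from functools import reduce
from operator import xor

seeds = {1: 123456, 2: 987654, 3: 42}        # client_id -> local seed
shared = {c: reduce(xor, seeds.values()) for c in seeds}
assert len(set(shared.values())) == 1        # everyone derives the same seed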
def _backward(self):
    # Default payload; stays None in Test rounds and when the network has no
    # trainable variables, which feature clients interpret as "skip update".
    input_grad = None
    if self.mpc_mode is MPCC.ClientMode.Train:
        grad_on_output = self.grad_on_output
        if len(self.network.trainable_variables) != 0:
            # Chain the label client's gradient through the network via the
            # full jacobian of the outputs w.r.t. each trainable variable.
            model_jacobians = self.gradient_tape.jacobian(
                self.network_out, self.network.trainable_variables)
            model_grad = [
                tf.reduce_sum(
                    model_jacobian *
                    (tf.reshape(
                        grad_on_output.astype(np.float32),
                        list(grad_on_output.shape) +
                        [1 for _ in range(len(model_jacobian.shape) - 2)]) +
                     tf.zeros_like(model_jacobian, dtype=model_jacobian.dtype)),
                    axis=[0, 1]) for model_jacobian in model_jacobians
            ]
            self.optimizer.apply_gradients(
                zip(model_grad, self.network.trainable_variables))
            input_jacobian = self.gradient_tape.jacobian(
                self.network_out, self.input_tensor)
            input_grad = tf.reduce_sum(
                input_jacobian *
                (tf.reshape(
                    grad_on_output.astype(np.float32),
                    list(grad_on_output.shape) +
                    [1 for i in range(len(input_jacobian.shape) - 2)]) +
                 tf.zeros_like(self.input_tensor, dtype=self.input_tensor.dtype)),
                axis=[0, 1]).numpy()
    if self.n_rounds == self.max_iter:
        self.broadcaster.broadcast(
            self.feature_client_ids,
            PackedMessage(MessageType.SharedNN_FeatureClientGrad,
                          (input_grad, "Stop")))
        try:
            self.send_check_msg(
                self.crypto_producer_id,
                PackedMessage(MessageType.Common_Stop, None, key="Stop"))
        except:
            self.logger.logW("Send stop message to triplet provider failed.")
    elif (self.n_rounds + 1) % self.test_per_batches == 0:
        self.mpc_mode = MPCC.ClientMode.Test
        self.broadcaster.broadcast(
            self.feature_client_ids,
            PackedMessage(MessageType.SharedNN_FeatureClientGrad,
                          (input_grad, "Continue-Test")))
    else:
        self.mpc_mode = MPCC.ClientMode.Train
        self.broadcaster.broadcast(
            self.feature_client_ids,
            PackedMessage(MessageType.SharedNN_FeatureClientGrad,
                          (input_grad, "Continue-Train")))
    if self.broadcaster.error:
        self.logger.logE("Broadcast feature client's grads failed. Stop training.")
        return False
    return True
def start_align(self):
    self.logger.log("Start generating random keys and ivs")
    self.generate_raw_key_and_iv()
    self.__send_aes_keys()
    if self.error:
        self.logger.logE("Send aes keys failed, stop align")
        return False
    self.__receive_aes_keys()
    if self.error:
        self.logger.logE("Receive aes keys failed, stop align")
        return False
    self.__generate_shared_key()
    self.logger.log("Shared key generated. Start encrypt ids.")
    try:
        self.__load_and_enc_data()
    except:
        self.logger.logE("Error while loading and encrypting data. Stop align")
        return False
    self.logger.log("Start sending encrypted ids to main preprocessor")
    try:
        self.send_check_msg(
            self.main_client_id,
            PackedMessage(MessageType.ALIGN_ENC_IDS, self.sample_ids))
    except:
        self.logger.logE("Error while sending aligned ids to align server. Stop align")
        return False
    try:
        msg = self.receive_check_msg(self.main_client_id,
                                     MessageType.ALIGN_FINAL_IDS)
    except:
        self.logger.logE("Error while receiving ids intersection from align_server")
        return False
    self.logger.log("Received aligned ids. Start making aligned data.")
    # Decrypt the intersection ids with the shared AES key and strip the
    # zero-byte padding added before encryption.
    encrypted_ids = msg.data
    selected_ids = list()
    for sample_id in encrypted_ids:
        selected_ids.append(
            AES.new(self.shared_aes_key, AES.MODE_CBC,
                    self.shared_aes_iv).decrypt(sample_id).decode(
                        'utf8').replace('\0', ''))
    aligned_data = self.data.loc[selected_ids]
    file_name = pathlib.Path(self.filepath).name
    self.out_indexed_file = self.out_dir + file_name[:-4] + "_aligned.csv"
    aligned_data.to_csv(self.out_indexed_file, index=True)
    # Hold out the last fifth of the aligned samples as the test set.
    test_size = int(len(selected_ids) / 5)
    train_data = aligned_data.iloc[:-test_size]
    test_data = aligned_data.iloc[-test_size:]
    self.train_data_path = self.out_dir + "train.csv"
    self.test_data_path = self.out_dir + "test.csv"
    if self.cols is None:
        self.cols = self.data.columns
    if isinstance(self.cols[0], int):
        self.cols = self.data.columns[self.cols]
    if not set(self.cols) <= set(train_data.columns):
        self.logger.logE("Selected columns not in the csv.")
        return False
    train_data[self.cols].to_csv(self.train_data_path, header=False, index=False)
    test_data[self.cols].to_csv(self.test_data_path, header=False, index=False)
    self.logger.log("Align finished, aligned file saved in " + self.out_dir)
    return True
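
# A standalone sketch (assumed; the project's __load_and_enc_data is not
# shown in this excerpt) of the encryption the decryption loop above undoes:
# ids are zero-padded to the AES block size and encrypted under the shared
# key/iv, so identical ids from different parties produce identical
# ciphertexts and can be intersected without revealing the ids themselves.
from Crypto.Cipher import AES            # pycryptodome
from Crypto.Random import get_random_bytes

key, iv = get_random_bytes(16), get_random_bytes(16)

def enc_id(sample_id: str) -> bytes:
    padded = sample_id.encode("utf8")
    padded += b"\0" * (-len(padded) % AES.block_size)   # zero-pad to 16 bytes
    return AES.new(key, AES.MODE_CBC, iv).encrypt(padded)

cipher = enc_id("user_42")
plain = AES.new(key, AES.MODE_CBC,
                iv).decrypt(cipher).decode("utf8").replace("\0", "")
assert plain == "user_42"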
def _local_train_one_round(self):
    """Run one round of local XGBoost training and send the gain."""

    # Send this client's gain to the main client.
    def send_gain_to(client_id: int, update: float):
        try:
            self.send_check_msg(
                client_id, PackedMessage(MessageType.XGBOOST_GAIN, update))
        except:
            self.logger.logE("Error encountered while sending gain to other client")
            self.error = True

    # Receive the selected-node decision ('ack' or 'rej').
    select_node = None

    def receive_node(client_id: int):
        try:
            nonlocal select_node
            select_node = self.receive_check_msg(
                client_id, MessageType.XGBOOST_SELECTED_NODE).data
        except:
            self.logger.logE(
                "Error encountered while receiving selected info from client %d"
                % client_id)
            self.error = True

    X_train = self.data
    # Compute the first/second-order loss terms. On the first round the
    # working labels equal the raw labels, so start from the mean prediction;
    # afterwards use the accumulated residual. (The original check was
    # `self.raw_label.all() == self.label_data.all()`, which compares two
    # booleans; comparing the arrays themselves is the intended test.)
    if np.array_equal(self.raw_label, self.label_data):
        loss = LogLoss(self.raw_label,
                       np.ones_like(self.raw_label) * np.mean(self.raw_label))
    else:
        loss = LogLoss(self.raw_label, self.raw_label - self.label_data)
    g, h = loss.g(), loss.h()
    estimator_t = CART(reg_lambda=self.reg_lambda,
                       max_depth=self.max_depth,
                       gamma=self.gamma,
                       col_sample_ratio=self.col_sample_ratio,
                       row_sample_ratio=self.row_sample_ratio)
    estimator_t.fit(X_train, self.label_data, g, h)
    # Update the residual with this tree's (learning-rate-scaled) prediction.
    label = self.label_data - (self.learning_rate *
                               np.expand_dims(estimator_t.predict(X_train), axis=1))
    gain = estimator_t.obj_val  # the smaller the better
    sending_th = threading.Thread(target=send_gain_to,
                                  args=(self.main_client_id, gain))
    sending_th.start()
    sending_th.join()
    receiving_th = threading.Thread(target=receive_node,
                                    args=(self.main_client_id,))
    receiving_th.start()
    receiving_th.join()
    # Keep this round's tree only if the main client selected this client.
    if select_node == 'ack':
        print('epoch ---- loss', np.mean(loss.forward()),
              '----client_id', self.client_id)
        self.estimators.append(estimator_t)
    self.send_check_msg(self.main_client_id,
                        PackedMessage(MessageType.XGBOOST_RESIDUAL, label))
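
# LogLoss is not shown in this excerpt; a minimal sketch of the standard
# XGBoost logistic-loss terms it is assumed to provide -- gradient
# g = sigmoid(pred) - y and hessian h = sigmoid(pred) * (1 - sigmoid(pred)).
import numpy as np

class LogLossSketch:
    def __init__(self, y, pred):
        self.y = y
        self.p = 1.0 / (1.0 + np.exp(-pred))   # predicted probability

    def forward(self):
        # Per-sample binary cross entropy.
        return -(self.y * np.log(self.p) + (1 - self.y) * np.log(1 - self.p))

    def g(self):
        return self.p - self.y                 # first-order term

    def h(self):
        return self.p * (1 - self.p)           # second-order term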
def decode_ComputationData(computation_data: message_pb2.ComputationData):
    data, key = pickle.loads(computation_data.python_bytes)
    return PackedMessage(MessageType(computation_data.type), data, key)
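
# The matching encoder is not shown in this excerpt; a sketch of what it
# presumably looks like, given that the decoder above reads `type` and
# unpickles a (data, key) pair from `python_bytes` (field names per the
# project's message_pb2 as used elsewhere in this excerpt):
def encode_ComputationData_sketch(msg: PackedMessage):
    return message_pb2.ComputationData(
        type=msg.header.value,
        python_bytes=pickle.dumps((msg.data, msg.key)))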
def _forward(self):
    if self.mpc_mode is MPCC.ClientMode.Train:
        self.batch_data = self.train_data_loader.get_batch(self.batch_size)
    else:
        self.logger.log("Test round: load data from test dataset")
        self.batch_data = self.test_data_loader.get_batch(self.test_batch_size)

    def thread_mul_with_client(client_id: int):
        # Order the two shared multiplications by client id so both parties
        # run the AB and BA protocols in matching order.
        multiplier = self.multipliers[client_id]
        if self.client_id < client_id:
            if not multiplier.multiply_AB_with(
                    client_id, self.crypto_producer_id,
                    (self.data_dim, self.output_dim), self.batch_data):
                self.error = True
                return
            self.shared_out_AB[client_id] = multiplier.product
            if not multiplier.multiply_BA_with(
                    client_id, self.crypto_producer_id,
                    (self.batch_data.shape[0], self.other_feature_dims[client_id]),
                    self.other_paras[client_id]):
                self.error = True
                return
            self.shared_out_BA[client_id] = multiplier.product
        else:
            if not multiplier.multiply_BA_with(
                    client_id, self.crypto_producer_id,
                    (self.batch_data.shape[0], self.other_feature_dims[client_id]),
                    self.other_paras[client_id]):
                self.error = True
                return
            self.shared_out_BA[client_id] = multiplier.product
            if not multiplier.multiply_AB_with(
                    client_id, self.crypto_producer_id,
                    (self.data_dim, self.output_dim), self.batch_data):
                self.error = True
                return
            self.shared_out_AB[client_id] = multiplier.product

    mul_threads = []
    for client_id in self.other_feature_client_ids:
        mul_threads.append(
            threading.Thread(target=thread_mul_with_client, args=(client_id,)))
        mul_threads[-1].start()
    for mul_thread in mul_threads:
        mul_thread.join()
    if self.error:
        self.logger.logE("Multiplication on shared parameters with other clients failed. Stop training.")
        return False
    self.own_out = self.batch_data @ self.para
    try:
        self.send_check_msg(
            self.main_client_id,
            PackedMessage(MessageType.SharedNN_FeatureClientOut,
                          (self.own_out, self.shared_out_AB, self.shared_out_BA)))
    except:
        self.logger.logE("Send outputs to main client failed. Stop training.")
        return False
    return True