Пример #1
0
 def __generate_and_send_triplets(self, op_position: int, clients, shapes):
     """Generate a multiplication triplet and send one share to each client.

     ``u = u0 + u1`` has shape ``shapes[0]``, ``v = v0 + v1`` has shape
     ``shapes[1]`` and ``z = u @ v`` is split additively into ``z0 + z1``.

     :param op_position: when 2, the order of ``clients`` and ``shapes`` is
         swapped before the triplet is generated.
     :param clients: pair of client ids. ``clients[0]`` receives
         ``(clients[1], u0, v0, z0)`` keyed by ``clients[1]``;
         ``clients[1]`` receives ``(clients[0], v1, u1, z1)`` keyed by
         ``clients[0]``.
     :param shapes: pair of shapes for the ``u`` and ``v`` matrices.

     A failed send is logged and does not prevent the other send.
     """
     # Decide which operand is the multiplier and which the multiplicand.
     if op_position == 2:
         shapes = [shapes[1], shapes[0]]
         clients = [clients[1], clients[0]]
     u0 = np.random.uniform(-1, 1, shapes[0])
     u1 = np.random.uniform(-1, 1, shapes[0])
     v0 = np.random.uniform(-1, 1, shapes[1])
     v1 = np.random.uniform(-1, 1, shapes[1])
     z = np.matmul(u0 + u1, v0 + v1)
     z0 = z * np.random.uniform(0, 1, z.shape)
     z1 = z - z0
     # (receiver, peer used as message key, payload). The second receiver
     # gets its v-share before its u-share.
     deliveries = (
         (clients[0], clients[1], (clients[1], u0, v0, z0)),
         (clients[1], clients[0], (clients[0], v1, u1, z1)),
     )
     for receiver, peer, payload in deliveries:
         try:
             self.send_check_msg(
                 receiver,
                 PackedMessage(MessageType.Triplet_Array, payload, peer))
         except Exception:
             self.logger.logE("Sending triplets to client %d failed" %
                              receiver)
Пример #2
0
 def GetComputationData(self, request, context):
     """Handle an incoming ComputationData RPC.

     The decoded request is passed to ``self.msg_handler`` along with
     ``request.client_id``. The reply is an encoded RECEIVED_OK when the
     handler accepts the message. Otherwise it is an encoded RECEIVED_ERR
     that carries a retry hint.
     """
     accepted = self.msg_handler(decode_ComputationData(request),
                                 request.client_id)
     if not accepted:
         reply = PackedMessage(header=MessageType.RECEIVED_ERR,
                               data="Buffer occupied, send it later")
     else:
         reply = PackedMessage(header=MessageType.RECEIVED_OK, data=None)
     return encode_ComputationData(reply)
    def multiply_BA_with(self, client_id: int, triplet_id: int,
                         client_shape: tuple, matB: np.ndarray):
        """Compute this party's share of ``A @ B`` where this party holds ``B``.

        Steps, as implemented below:

        1. Request a triplet from ``triplet_id`` (``Triplet_Set`` with operand
           position 2) and receive this party's ``(V, U, W)`` shares.
        2. Split ``matB`` additively. Send one share to ``client_id`` and
           receive this party's share of the other operand ``A``.
        3. Exchange the masked shares ``A - U`` and ``B - V`` with the peer.
        4. Store ``U @ (B - V) + (A - U) @ V + W`` (own U, V, W shares) in
           ``self.product``.

        :param client_id: id of the peer holding ``A``.
        :param triplet_id: id of the triplet provider.
        :param client_shape: shape of the peer's matrix ``A``.
        :param matB: this party's matrix ``B``.
        :return: True on success. False if any exchange failed; the failure
            is logged.
        """
        try:
            self.send_check_msg(
                triplet_id,
                PackedMessage(MessageType.Triplet_Set,
                              (2, client_id, matB.shape, client_shape)))
            triplet_msg = self.receive_check_msg(triplet_id,
                                                 MessageType.Triplet_Array,
                                                 key=client_id).data
            # triplet_msg[0] is skipped (presumably the peer id). This side
            # receives its V share before its U share.
            shared_V_self, shared_U_self, shared_W_self = triplet_msg[1:]
        except Exception:
            self.logger.logE(
                "Get triplet arrays failed. Stop multiplication with client %d."
                % client_id)
            return False
        try:
            # Split B additively: keep one random share and send the other.
            shared_B_self = matB * np.random.uniform(0, 1, matB.shape)
            shared_B_other = matB - shared_B_self
            self.send_check_msg(
                client_id,
                PackedMessage(MessageType.MUL_Mat_Share, shared_B_other))
            shared_A_self = self.receive_check_msg(
                client_id, MessageType.MUL_Mat_Share).data
        except Exception:
            self.logger.logE(
                "Swap mat shares with client failed. Stop multiplication with client %d."
                % client_id)
            return False
        try:
            # Exchange the masked shares so both sides learn A - U and B - V.
            shared_A_sub_U_self = shared_A_self - shared_U_self
            self.send_check_msg(
                client_id,
                PackedMessage(MessageType.MUL_AsubU_Share,
                              shared_A_sub_U_self))
            shared_A_sub_U_other = self.receive_check_msg(
                client_id, MessageType.MUL_AsubU_Share).data

            shared_B_sub_V_self = shared_B_self - shared_V_self
            self.send_check_msg(
                client_id,
                PackedMessage(MessageType.MUL_BsubV_Share,
                              shared_B_sub_V_self))
            shared_B_sub_V_other = self.receive_check_msg(
                client_id, MessageType.MUL_BsubV_Share).data
        except Exception:
            self.logger.logE(
                "Swap A-U and B-V with client failed. Stop multiplication with client %d"
                % client_id)
            return False

        A_sub_U = shared_A_sub_U_self + shared_A_sub_U_other
        B_sub_V = shared_B_sub_V_self + shared_B_sub_V_other
        # This party's share only. The (A-U)@(B-V) term is not added here;
        # presumably the AB side adds it.
        self.product = shared_U_self @ B_sub_V + A_sub_U @ shared_V_self + shared_W_self
        return True
Пример #4
0
    def compare_split_node(self):
        """Select the feature client with the smallest reported value and finish the round.

        Steps, as implemented below:

        1. Gather each feature client's XGBOOST_GAIN value.
        2. Select the client with the minimum value (the first one on ties).
           Send it 'ack' and send 'rej' to every other feature client.
        3. Call ``self.get_residual`` for the selected client.
        4. Gather every feature client's XGBOOST_PRED_LABEL, sum them and
           evaluate ``self.metric_func`` against ``self.raw_label_data``.
           Send ``(self.label_data.mean(), metric)`` to the label client.

        :return: True on success, False (after logging) on any failure.
        """
        client_outs = self.broadcaster.receive_all(self.feature_client_ids,
                                                   MessageType.XGBOOST_GAIN)
        if self.broadcaster.error:
            self.logger.logE("Gather clients' outputs failed. Stop training")
            return False
        # Send ack to the client with the min loss and rej to the others.
        selected = min(self.feature_client_ids,
                       key=lambda c: client_outs[c],
                       default=None)
        msgs = {
            data_client: PackedMessage(
                MessageType.XGBOOST_SELECTED_NODE,
                'ack' if data_client == selected else 'rej')
            for data_client in self.feature_client_ids
        }

        self.broadcaster.broadcast(self.feature_client_ids, msgs)
        if self.broadcaster.error:
            self.logger.logE("Broadcast split info failed. Stop training")
            return False

        if not self.get_residual(selected):
            return False

        preds = self.broadcaster.receive_all(self.feature_client_ids,
                                             MessageType.XGBOOST_PRED_LABEL)
        if self.broadcaster.error:
            self.logger.logE(
                "Get predict for each epoch failed. Stop Training")
            return False

        # Sum the per-client predictions into one column vector.
        pred = np.zeros((self.raw_label_data.shape[0]))
        for client_pred in preds.values():
            pred += client_pred
        pred = pred.reshape(-1, 1)
        try:
            metric = self.metric_func(self.raw_label_data, pred)
            msg = PackedMessage(MessageType.XGBOOST_PRED_LABEL,
                                (self.label_data.mean(), metric))
            self.send_check_msg(self.label_client_id, msg)
        except Exception:
            self.logger.logE(
                "Send predict label to label client fail. Stop Training")
            return False
        return True
Пример #5
0
 def send(self, receiver: int, msg: PackedMessage, time_out: float = None):
     """Send ``msg`` to ``receiver``, retrying while the reply is RECEIVED_ERR.

     :param receiver: id of the receiving client.
     :param msg: the message to send.
     :param time_out: maximum time in seconds to keep retrying. When not
         set (or falsy), ``self.time_out`` is used. The same value is also
         passed as the per-call timeout to the RPC client.
     :return: the last response. Its header is RECEIVED_ERR if the send
         timed out or the RPC call raised.
     """
     if not time_out:
         time_out = self.time_out
     resp = PackedMessage(MessageType.RECEIVED_ERR, None)
     start_send_time = time.time()
     while resp.header == MessageType.RECEIVED_ERR:
         elapsed = time.time() - start_send_time
         if elapsed > time_out:
             # Report the time actually spent, not the configured limit.
             self.logger.log(
                 "Timeout while sending to client %d. Time elapsed %.3f" %
                 (receiver, elapsed))
             return resp
         try:
             resp = self.rpc_clients[receiver].sendComputationMessage(
                 msg, self.client_id, time_out)
         except Exception as e:
             self.logger.logE("Error while sending to client %d" %
                              receiver + ":" + str(e))
             return resp
     return resp
Пример #6
0
 def _broadcast_start(self, stop=False):
     """Broadcast ``self.label_data`` to all feature clients and the label client.

     The header is XGBOOST_TRAINING_STOP when ``stop`` is truthy and
     XGBOOST_NEXT_TRAIN_ROUND otherwise. The broadcast result is not checked
     here. Failures are presumably left in ``self.broadcaster.error`` for
     the caller.
     """
     header = (MessageType.XGBOOST_TRAINING_STOP if stop
               else MessageType.XGBOOST_NEXT_TRAIN_ROUND)
     receivers = self.feature_client_ids + [self.label_client_id]
     self.broadcaster.broadcast(receivers,
                                PackedMessage(header, self.label_data))
    def _before_training(self):
        """Build the local MLP and send the training config to all data clients.

        Gathers each feature client's SharedNN_ClientDim. Then broadcasts a
        config dict (client dims, ``out_dim``, batch sizes, learning rate)
        to the feature clients and the label client.

        :return: True on success, False (after logging) on any failure.
        """
        try:
            self._build_mlp_model(self.in_dim, self.layers)
        except Exception:
            self.logger.logE("Build tensorflow model failed. Stop training.")
            return False

        self.client_dims = self.broadcaster.receive_all(
            self.feature_client_ids, MessageType.SharedNN_ClientDim)
        if self.broadcaster.error:
            self.logger.logE("Gather clients' dims failed. Stop training.")
            return False
        train_config = {
            "client_dims": self.client_dims,
            # Presumably the feature clients' output feeds this model's
            # input layer, hence out_dim == in_dim.
            "out_dim": self.in_dim,
            "batch_size": self.batch_size,
            "test_batch_size": self.test_batch_size,
            "learning_rate": self.learning_rate
        }
        self.broadcaster.broadcast(
            self.feature_client_ids + [self.label_client_id],
            PackedMessage(MessageType.SharedNN_TrainConfig, train_config))
        if self.broadcaster.error:
            self.logger.logE(
                "Broadcast training config message failed. Stop training.")
            return False
        return True
Пример #8
0
    def predict(self):
        """Evaluate the summed feature-client predictions against the label client's labels.

        Sends XGBOOST_TRAIN=False to the label client and receives its
        XGBOOST_LABEL reply as ``y_true``. Gathers every feature client's
        XGBOOST_PRED_LABEL, sums them into one column and logs
        ``self.metric_func(y_true, y_preds)``.

        :return: True when the metric was computed and logged. False (after
            logging) if a message exchange failed. Success previously fell
            through and returned None.
        """
        self.logger.log("MainClient predict start...")

        # Get the labels. False presumably asks for the test labels rather
        # than the training labels.
        try:
            self.send_check_msg(self.label_client_id,
                                PackedMessage(MessageType.XGBOOST_TRAIN, False))
            y_true = self.receive_check_msg(self.label_client_id, MessageType.XGBOOST_LABEL).data
        except Exception:
            self.logger.logE("Get test label from label client failed. Stop predict.")
            return False

        y_preds = np.zeros((y_true.shape[0]))

        y_pred_dict = self.broadcaster.receive_all(self.feature_client_ids, MessageType.XGBOOST_PRED_LABEL)
        if self.broadcaster.error:
            self.logger.logE("Gather predict y failed")
            return False

        for y_pred in y_pred_dict.values():
            y_preds += y_pred

        y_preds = y_preds.reshape(-1, 1)
        res = self.metric_func(y_true, y_preds)

        self.logger.log("Predict metric {}".format(res))
        return True
Пример #9
0
    def _before_training(self):
        """Exchange dimensions with the main client and initialise the weights.

        Sends this client's ``data_dim`` to the main client and reads the
        SharedNN_TrainConfig it returns (client dims, output dim, batch
        sizes, learning rate). Then draws zero-mean normal weights: one
        matrix per entry of ``self.other_feature_dims`` (into
        ``self.other_paras``) plus this client's own ``self.para``. The
        scale is ``1 / (n_feature_clients * dim)``.

        :return: True on success, False (after logging) on any failure.
        """
        if not super(FeatureClient, self)._before_training():
            return False
        try:
            self.send_check_msg(
                self.main_client_id,
                PackedMessage(MessageType.SharedNN_ClientDim, self.data_dim))
            config = self.receive_check_msg(
                self.main_client_id, MessageType.SharedNN_TrainConfig).data
            self.logger.log(
                "Received main client's config message: {}".format(config))
            self.other_feature_dims = config["client_dims"]
            self.output_dim = config["out_dim"]
            self.batch_size = config["batch_size"]
            self.test_batch_size = config["test_batch_size"]
            self.learning_rate = config["learning_rate"]
        except Exception:
            self.logger.logE(
                "Get training config from server failed. Stop training.")
            return False

        try:
            for client_id in self.other_feature_dims:
                other_dim = self.other_feature_dims[client_id]
                self.other_paras[client_id] = np.random.normal(
                    0, 1 / (len(self.feature_client_ids) * other_dim),
                    [other_dim, self.output_dim])
        except Exception:
            self.logger.logE("Initialize weights failed. Stop training.")
            return False
        self.para = np.random.normal(
            0, 1 / (len(self.feature_client_ids) * self.data_dim),
            [self.data_dim, self.output_dim])

        return True
Пример #10
0
 def send_gain_to(client_id: int, update: float):
     """Send ``update`` to ``client_id`` as an XGBOOST_GAIN message.

     ``self`` is taken from the enclosing scope. On failure the error is
     logged and ``self.error`` is set instead of raising.
     """
     try:
         self.send_check_msg(
             client_id, PackedMessage(MessageType.XGBOOST_GAIN, update))
     except Exception:
         self.logger.logE(
             "Error encountered while sending gain to other client")
         self.error = True
Пример #11
0
 def __send_aligned_ids_to(self, client_id):
     """Send ``self.aligned_ids`` to ``client_id`` as an ALIGN_FINAL_IDS message.

     On failure the error is logged and ``self.error`` is set instead of
     raising.
     """
     try:
         self.send_check_msg(
             client_id,
             PackedMessage(MessageType.ALIGN_FINAL_IDS, self.aligned_ids))
     except Exception:
         self.logger.logE("Sending aligned ids to client %d failed" %
                          client_id)
         self.error = True
Пример #12
0
 def send_pred(server_id, update):
     """Send ``update`` to ``server_id`` as an XGBOOST_PRED_LABEL message.

     ``self`` is taken from the enclosing scope. On failure the error is
     logged and ``self.error`` is set instead of raising.
     """
     try:
         self.send_check_msg(
             server_id,
             PackedMessage(MessageType.XGBOOST_PRED_LABEL, update))
     except Exception:
         self.logger.logE(
             "Error encountered while sending label to other client")
         self.error = True
Пример #13
0
 def __send_aes_key_to(self, client_id):
     """Send ``(self.shared_aes_key, self.shared_aes_iv)`` to ``client_id``.

     The pair is sent as an ALIGN_AES_KEY message. On failure the error is
     logged and ``self.error`` is set instead of raising.
     """
     try:
         self.send_check_msg(
             client_id,
             PackedMessage(MessageType.ALIGN_AES_KEY,
                           (self.shared_aes_key, self.shared_aes_iv)))
     except Exception:
         self.logger.logE("Send raw aes key to client %d failed." %
                          client_id)
         self.error = True
Пример #14
0
 def _get_raw_abel_data(self):
     """Fetch labels from the label client.

     Sends XGBOOST_TRAIN with ``True`` to the label client. The label client
     presumably treats that as a request for the training labels. The reply
     is stored in ``self.label_data``, and an untouched deep copy is kept in
     ``self.raw_label_data``. The "abel" typo in the name is kept for
     existing callers.

     :return: True on success, False (after logging) on failure.
     """
     try:
         self.send_check_msg(self.label_client_id,
                             PackedMessage(MessageType.XGBOOST_TRAIN, True))
         self.label_data = self.receive_check_msg(self.label_client_id, MessageType.XGBOOST_LABEL).data
         self.raw_label_data = deepcopy(self.label_data)
     except Exception:
         self.logger.logE("Receive label from label client failed. Stop training.")
         return False
     return True
Пример #15
0
    def _backward(self):
        """Apply this round's parameter updates from the main client's gradient.

        Receives ``(grad_on_output, status)`` from the main client. When the
        gradient is not None, this client's parameter gradient
        (``batch_data.T @ grad_on_output``) is split into random portions
        that sum to 1. Portion 0 updates ``self.para``, and portion ``i``
        (``i >= 1``) is sent to the i-th other feature client. The portions
        received from the other clients update the matching entries of
        ``self.other_paras``.

        ``status`` then selects the next step. "Stop" marks training as
        finished, "Continue-Test" switches to test mode, and anything else
        switches to train mode.

        :return: True on success, False (after logging) if any exchange
            failed.
        """
        try:
            grad_on_output, status = self.receive_check_msg(
                self.main_client_id,
                MessageType.SharedNN_FeatureClientGrad).data
        except Exception:
            self.logger.logE(
                "Receive grads message from main client failed. Stop training."
            )
            return False

        if grad_on_output is not None:
            own_para_grad = self.batch_data.transpose() @ grad_on_output
            portion = np.random.uniform(0, 1, len(self.feature_client_ids))
            portion /= np.sum(portion)
            self.para -= self.learning_rate * own_para_grad * portion[0]
            other_grad_msgs = {
                other_id: PackedMessage(
                    MessageType.SharedNN_FeatureClientParaGrad,
                    own_para_grad * portion[i])
                for i, other_id in enumerate(self.other_feature_client_ids,
                                             start=1)
            }

            swapped = False
            try:
                self.broadcaster.broadcast(self.other_feature_client_ids,
                                           other_grad_msgs)
                # The broadcaster appears to report failures through its
                # ``error`` flag rather than by raising (other call sites
                # check it), so check the flag after each call.
                if not self.broadcaster.error:
                    self_para_grads = self.broadcaster.receive_all(
                        self.other_feature_client_ids,
                        MessageType.SharedNN_FeatureClientParaGrad)
                    if not self.broadcaster.error:
                        for other_id in self_para_grads:
                            self.other_paras[other_id] -= (
                                self.learning_rate * self_para_grads[other_id])
                        swapped = True
            except Exception:
                swapped = False
            if not swapped:
                self.logger.logE(
                    "Swap parameter gradients with other clients failed. Stop training."
                )
                return False

        if status == "Stop":
            self.logger.log(
                "Received Stop signal from main-client, stop training.")
            self.finished = True
        elif status == "Continue-Test":
            self.mpc_mode = MPCC.ClientMode.Test
        else:
            self.mpc_mode = MPCC.ClientMode.Train

        return True
Пример #16
0
    def send_label_to_main(self):
        """Answer the main client's label request.

        Waits for an XGBOOST_TRAIN message from the main client. A truthy
        payload selects ``self.label``; otherwise ``self.test_label`` is
        used. The selected labels are sent back as an XGBOOST_LABEL message.

        :return: True on success, False (after logging) on failure.
        """
        try:
            train = self.receive_check_msg(self.main_client_id, MessageType.XGBOOST_TRAIN).data
            data = self.label if train else self.test_label
            self.send_check_msg(self.main_client_id,
                                PackedMessage(MessageType.XGBOOST_LABEL, data))
        except Exception:
            self.logger.logE("Send Label to main client predictions failed. Stop training.")
            return False

        return True
Пример #17
0
    def _before_training(self):
        """Send the training configuration to every data client.

        ``self.config`` is broadcast as an XGBOOST_TRAIN_CONFIG message to all
        feature clients and to the label client.

        :return: True if the broadcast succeeded, otherwise False (after
            logging).
        """
        receivers = self.feature_client_ids + [self.label_client_id]
        config_msg = PackedMessage(MessageType.XGBOOST_TRAIN_CONFIG,
                                   self.config)
        self.broadcaster.broadcast(receivers, config_msg)

        if not self.broadcaster.error:
            # NOTE: fetching the raw labels (_get_raw_abel_data) and waiting
            # for CLIENT_READY from the feature clients are currently
            # disabled at this point.
            return True
        self.logger.logE("Broadcast training config message failed. Stop training.")
        return False
Пример #18
0
    def _compute_loss(self):
        """Compute the loss on the main client's output and send back its gradient.

        Receives ``(preds, mode)`` from the main client. ``mode`` selects the
        batch the predictions are compared with:

        - "Train" / "Train-Stop" load the next training batch.
        - "Test" / "Test-Stop" load the next test batch.
        - A "Stop" suffix marks training as finished after this round.

        The loss and metric are logged and, in test mode, appended to
        ``self.test_record``. ``(grad, loss)`` is then sent back to the main
        client.

        :return: True on success, False (after logging) on any failure.
        """
        try:
            preds, mode = self.receive_check_msg(
                self.main_client_id, MessageType.SharedNN_MainClientOut).data
            if mode in ("Train", "Train-Stop"):
                self.mpc_mode = MPCC.ClientMode.Train
                self.batch_data = self.train_data_loader.get_batch(
                    self.batch_size)
            elif mode in ("Test", "Test-Stop"):
                self.mpc_mode = MPCC.ClientMode.Test
                self.logger.log(
                    "Received Test signal. Load test set data to batch.")
                self.batch_data = self.test_data_loader.get_batch(
                    self.test_batch_size)
            if mode.endswith("Stop"):
                self.logger.log(
                    "Received Stop signal. Stop training after finish this round."
                )
                self.finished = True

            loss = self.loss_func.forward(self.batch_data, preds)
            metric = self.metric_func(self.batch_data, preds)
            if self.mpc_mode is MPCC.ClientMode.Test:
                # ``metric`` is concatenated onto a list, so metric_func
                # must return a list here.
                self.test_record.append(
                    [time.time() - self.start_time, self.n_rounds, loss] +
                    metric)
            self.logger.log(
                "Current batch: {} loss: {}, metric value: {}".format(
                    self.n_rounds, loss, metric))
            grad = self.loss_func.backward()
            self.send_check_msg(
                self.main_client_id,
                PackedMessage(MessageType.SharedNN_MainClientGradLoss,
                              (grad, loss)))
        except Exception:
            self.logger.logE(
                "Compute gradient for main client predictions failed. Stop training."
            )
            return False

        return True
 def _compute_loss(self):
     """Send the network output to the label client and receive the gradient.

     The output is tagged "Train" or "Test" according to ``self.mpc_mode``.
     Once ``n_rounds`` reaches ``max_iter``, a "-Stop" suffix is added and
     ``self.finished`` is set. The first element of the label client's
     SharedNN_MainClientGradLoss reply is stored in ``self.grad_on_output``.

     :return: True on success, False (after logging) on failure.
     """
     try:
         mode_str = "Train"
         if self.mpc_mode is MPCC.ClientMode.Test:
             mode_str = "Test"
         if self.n_rounds == self.max_iter:
             self.finished = True
             mode_str += "-Stop"
         self.send_check_msg(
             self.label_client_id,
             PackedMessage(MessageType.SharedNN_MainClientOut,
                           (self.network_out.numpy(), mode_str)))
         self.grad_on_output = self.receive_check_msg(
             self.label_client_id,
             MessageType.SharedNN_MainClientGradLoss).data[0]
     except Exception:
         self.logger.logE(
             "Get gradients from label client failed. Stop training.")
         return False
     return True
Пример #20
0
 def _before_training(self):
     """Agree on a shared random seed with the other data providers.

     This party draws a random seed, broadcasts it to every other feature
     client and the label client, and XORs in every seed it receives. If
     every party runs this same exchange, all of them end up with the same
     value (XOR is order-independent). The result is set on both the train
     and test data loaders.

     :return: True on success, False (after logging) if the swap failed.
     """
     self.logger.log("Start sync random seed with other data-providers.")
     random_seed = np.random.randint(0, 999999999)
     other_data_client_ids = self.feature_client_ids + [
         self.label_client_id
     ]
     other_data_client_ids.remove(self.client_id)
     self.broadcaster.broadcast(
         other_data_client_ids,
         PackedMessage(MessageType.SharedNN_RandomSeed, random_seed))
     all_seeds = self.broadcaster.receive_all(
         other_data_client_ids, MessageType.SharedNN_RandomSeed)
     if self.broadcaster.error:
         # A failure stops training, so log it as an error like every other
         # failure path does.
         self.logger.logE(
             "Swapping random seed with other data-providers failed. Stop training."
         )
         return False

     for c in all_seeds:
         random_seed ^= all_seeds[c]
     self.train_data_loader.set_random_seed(random_seed)
     self.test_data_loader.set_random_seed(random_seed)
     self.logger.log("Random seed swapped and set to data loaders.")
     return True
    def _backward(self):
        """Back-propagate the received output gradient and notify the feature clients.

        In train mode, ``self.grad_on_output`` (presumably dLoss/d network_out)
        is contracted with the Jacobians recorded on ``self.gradient_tape``.
        The result for the trainable variables is applied through
        ``self.optimizer``, and the result for ``self.input_tensor`` is sent
        to the feature clients. In test mode no gradient is computed and
        ``None`` is sent instead.

        The status sent with the gradient is:

        - "Stop" once ``n_rounds`` reaches ``max_iter``. The crypto producer
          is also told to stop, best-effort.
        - "Continue-Test" when the next round is a test round.
        - "Continue-Train" otherwise.

        The last two also set ``self.mpc_mode``.

        :return: True on success, False (after logging) if the broadcast
            failed.
        """
        input_grad = None
        if self.mpc_mode is MPCC.ClientMode.Train:
            grad_on_output = self.grad_on_output
            grad_f32 = grad_on_output.astype(np.float32)
            grad_dims = list(grad_on_output.shape)

            def broadcastable_grad(jacobian):
                # Append one singleton axis per trailing jacobian axis so the
                # output gradient broadcasts against ``jacobian``.
                return tf.reshape(grad_f32,
                                  grad_dims + [1] * (len(jacobian.shape) - 2))

            if len(self.network.trainable_variables) != 0:
                model_jacobians = self.gradient_tape.jacobian(
                    self.network_out, self.network.trainable_variables)
                model_grad = [
                    tf.reduce_sum(
                        jac * (broadcastable_grad(jac) +
                               tf.zeros_like(jac, dtype=jac.dtype)),
                        axis=[0, 1]) for jac in model_jacobians
                ]
                self.optimizer.apply_gradients(
                    zip(model_grad, self.network.trainable_variables))
            input_jacobian = self.gradient_tape.jacobian(
                self.network_out, self.input_tensor)
            input_grad = tf.reduce_sum(
                input_jacobian *
                (broadcastable_grad(input_jacobian) +
                 tf.zeros_like(self.input_tensor,
                               dtype=self.input_tensor.dtype)),
                axis=[0, 1]).numpy()

        if self.n_rounds == self.max_iter:
            status = "Stop"
        elif (self.n_rounds + 1) % self.test_per_batches == 0:
            self.mpc_mode = MPCC.ClientMode.Test
            status = "Continue-Test"
        else:
            self.mpc_mode = MPCC.ClientMode.Train
            status = "Continue-Train"

        self.broadcaster.broadcast(
            self.feature_client_ids,
            PackedMessage(MessageType.SharedNN_FeatureClientGrad,
                          (input_grad, status)))
        if status == "Stop":
            # Best-effort: a failure here only produces a warning.
            try:
                self.send_check_msg(
                    self.crypto_producer_id,
                    PackedMessage(MessageType.Common_Stop, None, key="Stop"))
            except Exception:
                self.logger.logW(
                    "Send stop message to triplet provider failed.")

        if self.broadcaster.error:
            self.logger.logE(
                "Broadcast feature client's grads failed. Stop training.")
            return False

        return True
Пример #22
0
    def start_align(self):
        """Align this client's samples with the other parties using encrypted ids.

        Steps, as implemented below:

        1. Generate a raw AES key/iv, exchange keys with the other parties
           and derive the shared key.
        2. Load the local data and encrypt its ids. Send the encrypted ids
           (``self.sample_ids``) to the main client.
        3. Receive the final encrypted id intersection, decrypt it with the
           shared key/iv and select those rows from ``self.data``.
        4. Save the aligned rows to ``<out_dir><file stem>_aligned.csv``.
           Split them into ``train.csv`` and ``test.csv``, where the test set
           is the last ``len // 5`` rows. Both files are restricted to
           ``self.cols`` and written without header or index.

        :return: True on success, False (after logging) on any failure.
        """
        self.logger.log("Start generating random keys and ivs")
        self.generate_raw_key_and_iv()
        self.__send_aes_keys()
        if self.error:
            self.logger.logE("Send aes keys failed, stop align")
            return False

        self.__receive_aes_keys()
        if self.error:
            self.logger.logE("Receive aes keys failed, stop align")
            return False

        self.__generate_shared_key()
        self.logger.log("Shared key generated. Start encrypt ids.")
        try:
            self.__load_and_enc_data()
        except Exception:
            self.logger.logE(
                "Error while loading and encrypting data. Stop align")
            return False
        self.logger.log("Start sending encrypted ids to main preprocessor")
        try:
            self.send_check_msg(
                self.main_client_id,
                PackedMessage(MessageType.ALIGN_ENC_IDS, self.sample_ids))
        except Exception:
            self.logger.logE(
                "Error while sending aligned ids to align server. Stop align")
            return False

        try:
            msg = self.receive_check_msg(self.main_client_id,
                                         MessageType.ALIGN_FINAL_IDS)
        except Exception:
            self.logger.logE(
                "Error while receiving ids intersection from align_server")
            return False

        self.logger.log("Received aligned ids. Start making aligned data.")
        # A fresh CBC cipher per id, since each id is decrypted on its own
        # with the same key/iv. NUL padding is stripped after decoding.
        selected_ids = [
            AES.new(self.shared_aes_key, AES.MODE_CBC,
                    self.shared_aes_iv).decrypt(sample_id).decode(
                        'utf8').replace('\0', '')
            for sample_id in msg.data
        ]

        aligned_data = self.data.loc[selected_ids]

        # ``stem`` drops the extension whatever its length (``name[:-4]``
        # only worked for 3-letter extensions).
        file_stem = pathlib.Path(self.filepath).stem

        self.out_indexed_file = self.out_dir + file_stem + "_aligned.csv"
        aligned_data.to_csv(self.out_indexed_file, index=True)

        # Hold out the last fifth as the test set. Use an explicit boundary:
        # with test_size == 0, ``iloc[:-0]`` would be empty and
        # ``iloc[-0:]`` would be everything.
        test_size = len(selected_ids) // 5
        n_train = len(aligned_data) - test_size
        train_data = aligned_data.iloc[:n_train]
        test_data = aligned_data.iloc[n_train:]

        self.train_data_path = self.out_dir + "train.csv"
        self.test_data_path = self.out_dir + "test.csv"
        if self.cols is None:
            self.cols = self.data.columns
        # Integer entries are treated as column positions.
        if len(self.cols) > 0 and isinstance(self.cols[0], int):
            self.cols = self.data.columns[self.cols]
        if not set(self.cols) <= set(train_data.columns):
            self.logger.logE("Selected columns not in the csv.")
            return False
        train_data[self.cols].to_csv(self.train_data_path,
                                     header=False,
                                     index=False)
        test_data[self.cols].to_csv(self.test_data_path,
                                    header=False,
                                    index=False)

        self.logger.log("Align finished, aligned file saved in " +
                        self.out_dir)
        return True
Пример #23
0
    def _local_train_one_round(self):
        """Fit one CART tree locally, report its gain and act on the main client's choice.

        The loss is built from ``self.raw_label`` and the current prediction:

        - When ``self.label_data`` still equals the raw labels (presumably the
          first round), the prediction is the label mean.
        - Otherwise the prediction is ``raw_label - label_data``.

        A CART tree is fitted to the loss gradients/hessians. Its objective
        value (``obj_val``, smaller is better) is sent to the main client as
        the XGBOOST_GAIN, and the XGBOOST_SELECTED_NODE reply is received. If
        this client was selected ('ack'), the tree is kept and the updated
        labels are sent to the main client as XGBOOST_RESIDUAL.

        Gain/selection exchange failures are logged and set ``self.error``.
        """
        # TODO: run one XGB round locally and send the gain information.
        X_train = self.data

        # Compare the arrays themselves. The previous
        # ``raw_label.all() == label_data.all()`` compared two booleans.
        if np.array_equal(self.raw_label, self.label_data):
            loss = LogLoss(self.raw_label, np.ones_like(self.raw_label) * np.mean(self.raw_label))
        else:
            loss = LogLoss(self.raw_label, self.raw_label - self.label_data)
        g, h = loss.g(), loss.h()

        estimator_t = CART(reg_lambda=self.reg_lambda, max_depth=self.max_depth, gamma=self.gamma,
                           col_sample_ratio=self.col_sample_ratio, row_sample_ratio=self.row_sample_ratio)
        estimator_t.fit(X_train, self.label_data, g, h)
        label = self.label_data - (self.learning_rate * np.expand_dims(estimator_t.predict(X_train), axis=1))
        # The smaller the objective value, the better the tree.
        gain = estimator_t.obj_val

        # Send the gain. This was previously run in a thread that was joined
        # immediately, which is the same as calling it inline.
        try:
            self.send_check_msg(self.main_client_id, PackedMessage(MessageType.XGBOOST_GAIN, gain))
        except Exception:
            self.logger.logE("Error encountered while sending gain to other client")
            self.error = True

        # Receive whether this client's tree was selected ('ack') or not.
        select_node = None
        try:
            select_node = self.receive_check_msg(self.main_client_id, MessageType.XGBOOST_SELECTED_NODE).data
        except Exception:
            self.logger.logE("Error encountered while receiving selected info from client %d" % self.main_client_id)
            self.error = True

        if select_node == 'ack':
            print('epoch ---- loss', np.mean(loss.forward()), '----client_id', self.client_id)
            self.estimators.append(estimator_t)
            self.send_check_msg(self.main_client_id, PackedMessage(MessageType.XGBOOST_RESIDUAL, label))
Пример #24
0
def decode_ComputationData(computation_data: message_pb2.ComputationData):
    """Decode a ComputationData protobuf into a PackedMessage.

    ``python_bytes`` is unpickled into a ``(data, key)`` pair. ``type`` is
    converted with ``MessageType(...)`` to become the message header.

    SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    payload. Only call this on bytes from trusted peers. If the RPC endpoint
    is reachable by untrusted parties, replace pickle with a data-only
    serialization format.
    """
    # NOTE(review): unsafe deserialization of network data -- see docstring.
    data, key = pickle.loads(computation_data.python_bytes)
    return PackedMessage(MessageType(computation_data.type), data, key)
Пример #25
0
    def _forward(self):
        """Run this feature client's forward pass for the next batch.

        Loads the next train or test batch, depending on ``self.mpc_mode``.
        Then, for every other feature client and in parallel threads, runs
        two secure multiplications:

        - ``multiply_AB_with``: this side holds ``batch_data``. The result
          goes into ``self.shared_out_AB``.
        - ``multiply_BA_with``: this side holds ``self.other_paras[peer]``.
          The result goes into ``self.shared_out_BA``.

        Finally sends ``(batch_data @ self.para, shared_out_AB,
        shared_out_BA)`` to the main client.

        :return: True on success, False (after logging) on any failure.
        """
        if self.mpc_mode is MPCC.ClientMode.Train:
            self.batch_data = self.train_data_loader.get_batch(self.batch_size)
        else:
            self.logger.log("Test round: load data from test dataset")
            self.batch_data = self.test_data_loader.get_batch(
                self.test_batch_size)

        def thread_mul_with_client(client_id: int):
            multiplier = self.multipliers[client_id]

            def mul_ab():
                # This side holds the batch; the peer holds the weights.
                if not multiplier.multiply_AB_with(
                        client_id, self.crypto_producer_id,
                        (self.data_dim, self.output_dim), self.batch_data):
                    return False
                self.shared_out_AB[client_id] = multiplier.product
                return True

            def mul_ba():
                # The peer holds its batch; this side holds the weights.
                if not multiplier.multiply_BA_with(
                        client_id, self.crypto_producer_id,
                        (self.batch_data.shape[0],
                         self.other_feature_dims[client_id]),
                        self.other_paras[client_id]):
                    return False
                self.shared_out_BA[client_id] = multiplier.product
                return True

            # The two sides of each pair run the products in opposite orders
            # (AB first on the smaller id), presumably so both peers work on
            # the same product at the same time.
            steps = ((mul_ab, mul_ba) if self.client_id < client_id else
                     (mul_ba, mul_ab))
            for step in steps:
                if not step():
                    self.error = True
                    return

        mul_threads = []
        for client_id in self.other_feature_client_ids:
            mul_thread = threading.Thread(target=thread_mul_with_client,
                                          args=(client_id, ))
            mul_thread.start()
            mul_threads.append(mul_thread)
        for mul_thread in mul_threads:
            mul_thread.join()
        if self.error:
            self.logger.logE(
                "Multiplication on shared parameters with other clients failed. Stop training."
            )
            return False

        self.own_out = self.batch_data @ self.para

        try:
            self.send_check_msg(
                self.main_client_id,
                PackedMessage(
                    MessageType.SharedNN_FeatureClientOut,
                    (self.own_out, self.shared_out_AB, self.shared_out_BA)))
        except Exception:
            self.logger.logE(
                "Send outputs to main client failed. Stop training.")
            return False

        return True