Example #1
class HeteroLRGuest(HeteroLRBase):
    def __init__(self):
        super().__init__()
        self.encrypted_error = None
        self.encrypted_wx = None
        self.z_square = None
        self.wx_self = None
        self.wx_remote = None

    def _cal_z_in_share(self, w_self, w_remote, features, suffix, cipher):
        z1 = features.dot_local(w_self)

        za_suffix = ("za", ) + suffix

        za_share = self.secure_matrix_obj.secure_matrix_mul(
            w_remote,
            tensor_name=".".join(za_suffix),
            cipher=cipher,
            suffix=za_suffix)
        zb_suffix = ("zb", ) + suffix
        zb_share = self.secure_matrix_obj.secure_matrix_mul(
            features,
            tensor_name=".".join(zb_suffix),
            cipher=None,
            suffix=zb_suffix)

        z = z1 + za_share + zb_share
        return z

    def _compute_sigmoid(self, z, remote_z):
        complete_z = z + remote_z

        # First-order Taylor approximation of the sigmoid around z = 0;
        # a linear form that can be evaluated over secret shares.
        sigmoid_z = complete_z * 0.25 + 0.5

        return sigmoid_z

    def forward(self, weights, features, suffix, cipher):
        if not self.reveal_every_iter:
            LOGGER.info(f"[forward]: Calculate z in share...")
            w_self, w_remote = weights
            z = self._cal_z_in_share(w_self, w_remote, features, suffix,
                                     cipher)
        else:
            LOGGER.info(f"[forward]: Calculate z directly...")
            w = weights.unboxed
            z = features.dot_local(w)

        remote_z = self.secure_matrix_obj.share_encrypted_matrix(
            suffix=suffix, is_remote=False, cipher=None, z=None)[0]

        self.wx_self = z
        self.wx_remote = remote_z

        sigmoid_z = self._compute_sigmoid(z, remote_z)

        self.encrypted_wx = self.wx_self + self.wx_remote

        self.encrypted_error = sigmoid_z - self.labels

        tensor_name = ".".join(("sigmoid_z", ) + suffix)
        shared_sigmoid_z = SecureMatrix.from_source(tensor_name, sigmoid_z,
                                                    cipher,
                                                    self.fixedpoint_encoder.n,
                                                    self.fixedpoint_encoder)
        return shared_sigmoid_z

    def backward(self, error, features, suffix, cipher):
        LOGGER.info(f"[backward]: Calculate gradient...")
        batch_num = self.batch_num[int(suffix[1])]
        error_1_n = error * (1 / batch_num)

        ga2_suffix = ("ga2", ) + suffix
        ga2_2 = self.secure_matrix_obj.secure_matrix_mul(
            error_1_n,
            tensor_name=".".join(ga2_suffix),
            cipher=cipher,
            suffix=ga2_suffix,
            is_fixedpoint_table=False)

        # LOGGER.debug(f"ga2_2: {ga2_2}")

        encrypt_g = self.encrypted_error.dot(features) * (1 / batch_num)

        # LOGGER.debug(f"encrypt_g: {encrypt_g}")

        tensor_name = ".".join(("encrypt_g", ) + suffix)
        gb2 = SecureMatrix.from_source(tensor_name, encrypt_g, self.cipher,
                                       self.fixedpoint_encoder.n,
                                       self.fixedpoint_encoder)

        # LOGGER.debug(f"gb2: {gb2}")

        return gb2, ga2_2

    def compute_loss(self, weights, suffix, cipher=None):
        """
          Use Taylor series expand log loss:
          Loss = - y * log(h(x)) - (1-y) * log(1 - h(x)) where h(x) = 1/(1+exp(-wx))
          Then loss' = - (1/N)*∑(log(1/2) - 1/2*wx + ywx -1/8(wx)^2)
        """
        LOGGER.info(f"[compute_loss]: Calculate loss ...")
        wx = (-0.5 * self.encrypted_wx).reduce(operator.add)
        ywx = (self.encrypted_wx * self.labels).reduce(operator.add)

        wx_square = (2 * self.wx_remote * self.wx_self).reduce(operator.add) + \
                    (self.wx_self * self.wx_self).reduce(operator.add)

        wx_remote_square = self.secure_matrix_obj.share_encrypted_matrix(
            suffix=suffix, is_remote=False, cipher=None,
            wx_self_square=None)[0]

        wx_square = (wx_remote_square + wx_square) * -0.125

        batch_num = self.batch_num[int(suffix[2])]
        loss = (wx + ywx + wx_square) * (-1 / batch_num) - np.log(0.5)

        tensor_name = ".".join(("shared_loss", ) + suffix)
        share_loss = SecureMatrix.from_source(
            tensor_name=tensor_name,
            source=loss,
            cipher=None,
            q_field=self.fixedpoint_encoder.n,
            encoder=self.fixedpoint_encoder)

        tensor_name = ".".join(("loss", ) + suffix)
        loss = share_loss.get(tensor_name=tensor_name, broadcast=False)[0]

        if self.reveal_every_iter:
            loss_norm = self.optimizer.loss_norm(weights)
            if loss_norm:
                loss += loss_norm
        else:
            if self.optimizer.penalty == consts.L2_PENALTY:
                w_self, w_remote = weights

                w_encode = np.hstack((w_remote.value, w_self.value))

                w_encode = np.array([w_encode])

                w_tensor_name = ".".join(("loss_norm_w", ) + suffix)
                w_tensor = fixedpoint_numpy.FixedPointTensor(
                    value=w_encode,
                    q_field=self.fixedpoint_encoder.n,
                    endec=self.fixedpoint_encoder,
                    tensor_name=w_tensor_name)

                w_tensor_transpose_name = ".".join(
                    ("loss_norm_w_transpose", ) + suffix)
                w_tensor_transpose = fixedpoint_numpy.FixedPointTensor(
                    value=w_encode.T,
                    q_field=self.fixedpoint_encoder.n,
                    endec=self.fixedpoint_encoder,
                    tensor_name=w_tensor_transpose_name)

                loss_norm_tensor_name = ".".join(("loss_norm", ) + suffix)

                loss_norm = w_tensor.dot(
                    w_tensor_transpose,
                    target_name=loss_norm_tensor_name).get(broadcast=False)
                loss_norm = 0.5 * self.optimizer.alpha * loss_norm[0][0]
                loss = loss + loss_norm

        LOGGER.info(
            f"[compute_loss]: loss={loss}, reveal_every_iter={self.reveal_every_iter}"
        )

        return loss

    def _reveal_every_iter_weights_check(self, last_w, new_w, suffix):
        square_sum = np.sum((last_w - new_w)**2)
        host_sums = self.converge_transfer_variable.square_sum.get(
            suffix=suffix)
        for hs in host_sums:
            square_sum += hs
        weight_diff = np.sqrt(square_sum)
        is_converge = False
        if weight_diff < self.model_param.tol:
            is_converge = True
        LOGGER.info(f"n_iter: {self.n_iter_}, weight_diff: {weight_diff}")
        self.converge_transfer_variable.converge_info.remote(is_converge,
                                                             role=consts.HOST,
                                                             suffix=suffix)
        return is_converge

    @assert_io_num_rows_equal
    def predict(self, data_instances):
        """
        Prediction of lr
        Parameters
        ----------
        data_instances: Table of Instance, input data

        Returns
        ----------
        Table
            include input data label, predict probably, label
        """
        self._abnormal_detection(data_instances)
        data_instances = self.align_data_header(data_instances, self.header)
        if self.need_one_vs_rest:
            predict_result = self.one_vs_rest_obj.predict(data_instances)
            return predict_result
        LOGGER.debug(
            f"Before_predict_reveal_strategy: {self.model_param.reveal_strategy}, {self.is_respectively_reveal}"
        )

        def _vec_dot(v, coef, intercept):
            return fate_operator.vec_dot(v.features, coef) + intercept

        f = functools.partial(_vec_dot,
                              coef=self.model_weights.coef_,
                              intercept=self.model_weights.intercept_)

        pred_prob = data_instances.mapValues(f)
        host_probs = self.transfer_variable.host_prob.get(idx=-1)

        LOGGER.info("Get probability from Host")

        # guest probability
        for host_prob in host_probs:
            if not self.is_respectively_reveal:
                host_prob = self.cipher.distribute_decrypt(host_prob)
            pred_prob = pred_prob.join(host_prob, lambda g, h: g + h)
        pred_prob = pred_prob.mapValues(lambda p: activation.sigmoid(p))
        threshold = self.model_param.predict_param.threshold
        predict_result = self.predict_score_to_output(data_instances,
                                                      pred_prob,
                                                      classes=[0, 1],
                                                      threshold=threshold)

        return predict_result

    def _get_param(self):
        if self.need_cv:
            param_protobuf_obj = lr_model_param_pb2.LRModelParam()
            return param_protobuf_obj

        if self.need_one_vs_rest:
            one_vs_rest_result = self.one_vs_rest_obj.save(
                lr_model_param_pb2.SingleModel)
            single_result = {
                'header': self.header,
                'need_one_vs_rest': True,
                "best_iteration": -1
            }
        else:
            one_vs_rest_result = None
            single_result = self.get_single_model_param()

            single_result['need_one_vs_rest'] = False
        single_result['one_vs_rest_result'] = one_vs_rest_result
        LOGGER.debug(f"saved_model: {single_result}")
        param_protobuf_obj = lr_model_param_pb2.LRModelParam(**single_result)
        return param_protobuf_obj

    def get_single_model_param(self, model_weights=None, header=None):
        result = super().get_single_model_param(model_weights, header)
        if not self.is_respectively_reveal:
            result["cipher"] = dict(
                public_key=dict(n=str(self.cipher.public_key.n)),
                private_key=dict(p=str(self.cipher.privacy_key.p),
                                 q=str(self.cipher.privacy_key.q)))
        return result

    def load_single_model(self, single_model_obj):
        super(HeteroLRGuest, self).load_single_model(single_model_obj)
        if not self.is_respectively_reveal:
            cipher_info = single_model_obj.cipher
            self.cipher = PaillierEncrypt()
            public_key = PaillierPublicKey(int(cipher_info.public_key.n))
            privacy_key = PaillierPrivateKey(public_key,
                                             int(cipher_info.private_key.p),
                                             int(cipher_info.private_key.q))
            self.cipher.set_public_key(public_key=public_key)
            self.cipher.set_privacy_key(privacy_key=privacy_key)

    def get_model_summary(self):
        summary = super(HeteroLRGuest, self).get_model_summary()
        return summary
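A standalone check of the linearized sigmoid used in _compute_sigmoid above: sigmoid(z) is replaced by its first-order Taylor expansion 0.25 * z + 0.5 so that the forward pass stays linear over secret shares. This is a minimal numpy sketch, independent of the FATE tensor types; the helper names are illustrative only.

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_taylor(z):
    # First-order expansion around z = 0, as in _compute_sigmoid.
    return 0.25 * z + 0.5

for z in np.linspace(-4.0, 4.0, 9):
    print(f"z={z:+.1f}  sigmoid={sigmoid(z):.4f}  approx={sigmoid_taylor(z):.4f}")

The approximation is tight near z = 0 and drifts for |z| > 2, so it presumes the scores wx stay in a moderate range.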
Example #2
class HomoLRGuest(HomoLRBase):
    def __init__(self):
        super(HomoLRGuest, self).__init__()
        self.gradient_operator = LogisticGradient()
        self.loss_history = []
        self.role = consts.GUEST
        self.aggregator = aggregator.Guest()

        self.zcl_encrypt_operator = PaillierEncrypt()

    def _init_model(self, params):
        super()._init_model(params)

    def fit(self, data_instances, validate_data=None):

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)

        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        self.model_weights = self._init_model_variables(data_instances)

        max_iter = self.max_iter
        total_data_num = data_instances.count()
        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        model_weights = self.model_weights

        self.__synchronize_encryption()
        self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get(
            idx=0, suffix=('train', ))
        LOGGER.debug("party num:" + str(self.zcl_num_party))
        self.__init_model()

        self.train_loss_results = []
        self.train_accuracy_results = []
        self.test_loss_results = []
        self.test_accuracy_results = []

        for iter_num in range(self.max_iter):
            total_loss = 0
            batch_num = 0
            epoch_train_loss_avg = tfe.metrics.Mean()
            epoch_train_accuracy = tfe.metrics.Accuracy()

            for train_x, train_y in self.zcl_dataset:
                LOGGER.info("Staring batch {}".format(batch_num))
                start_t = time.time()
                loss_value, grads = self.__grad(self.zcl_model, train_x,
                                                train_y)
                loss_value = loss_value.numpy()
                grads = [x.numpy() for x in grads]
                LOGGER.info("Start encrypting")
                loss_value = batch_encryption.encrypt(
                    self.zcl_encrypt_operator.get_public_key(), loss_value)
                grads = [
                    batch_encryption.encrypt_matrix(
                        self.zcl_encrypt_operator.get_public_key(), x)
                    for x in grads
                ]
                grads = Gradients(grads)
                LOGGER.info("Finish encrypting")
                # grads = self.encrypt_operator.get_public_key()
                self.transfer_variable.guest_grad.remote(
                    obj=grads.for_remote(),
                    role=consts.ARBITER,
                    idx=0,
                    suffix=(iter_num, batch_num))
                LOGGER.info("Sent grads")
                self.transfer_variable.guest_loss.remote(obj=loss_value,
                                                         role=consts.ARBITER,
                                                         idx=0,
                                                         suffix=(iter_num,
                                                                 batch_num))
                LOGGER.info("Sent loss")

                sum_grads = self.transfer_variable.aggregated_grad.get(
                    idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got grads")
                sum_loss = self.transfer_variable.aggregated_loss.get(
                    idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got loss")

                sum_loss = batch_encryption.decrypt(
                    self.zcl_encrypt_operator.get_privacy_key(), sum_loss)
                sum_grads = [
                    batch_encryption.decrypt_matrix(
                        self.zcl_encrypt_operator.get_privacy_key(),
                        x).astype(np.float32) for x in sum_grads.unboxed
                ]
                LOGGER.info("Finish decrypting")

                # sum_grads = np.array(sum_grads) / self.zcl_num_party

                self.zcl_optimizer.apply_gradients(
                    zip(sum_grads, self.zcl_model.trainable_variables),
                    self.zcl_global_step)

                elapsed_time = time.time() - start_t
                # epoch_train_loss_avg(loss_value)
                # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32),
                #                      train_y)
                self.train_loss_results.append(sum_loss)
                train_accuracy_v = accuracy_score(
                    train_y,
                    tf.argmax(self.zcl_model(train_x),
                              axis=1,
                              output_type=tf.int32))
                self.train_accuracy_results.append(train_accuracy_v)
                test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test,
                                          self.zcl_y_test)
                self.test_loss_results.append(test_loss_v)
                test_accuracy_v = accuracy_score(
                    self.zcl_y_test,
                    tf.argmax(self.zcl_model(self.zcl_x_test),
                              axis=1,
                              output_type=tf.int32))
                self.test_accuracy_results.append(test_accuracy_v)

                LOGGER.info(
                    "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, "
                    "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format(
                        iter_num, batch_num, sum_loss, train_accuracy_v,
                        test_loss_v, test_accuracy_v, elapsed_time))

                batch_num += 1

                if batch_num >= self.zcl_early_stop_batch:
                    return

            self.n_iter_ = iter_num

    def __synchronize_encryption(self, mode='train'):
        """
        Communicate with hosts. Specify whether use encryption or not and transfer the public keys.
        """
        pub_key = self.transfer_variable.paillier_pubkey.get(idx=0,
                                                             suffix=(mode, ))
        LOGGER.debug("Received pubkey")
        self.zcl_encrypt_operator.set_public_key(pub_key)
        pri_key = self.transfer_variable.paillier_prikey.get(idx=0,
                                                             suffix=(mode, ))
        LOGGER.debug("Received prikey")
        self.zcl_encrypt_operator.set_privacy_key(pri_key)

    def __init_model(self):
        # self.zcl_model = keras.Sequential([
        #     keras.layers.Flatten(input_shape=(28, 28)),
        #     keras.layers.Dense(128, activation=tf.nn.relu),
        #     keras.layers.Dense(10, activation=tf.nn.softmax)
        # ])
        #
        with open(MODEL_JSON_DIR, 'r') as json_file:
            loaded_model_json = json_file.read()
        loaded_model = keras.models.model_from_json(loaded_model_json)
        loaded_model.load_weights(MODEL_WEIGHT_DIR)
        self.zcl_model = loaded_model
        LOGGER.info("Initialed model")

        # The data, split between train and test sets:
        (x_train,
         y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train /= 255.0
        x_test /= 255.0
        y_train = y_train.squeeze().astype(np.int32)
        y_test = y_test.squeeze().astype(np.int32)

        avg_length = int(len(x_train) / self.zcl_num_party)
        split_idx = [_ * avg_length for _ in range(1, self.zcl_num_party)]
        x_train = np.split(x_train, split_idx)[self.zcl_idx]
        y_train = np.split(y_train, split_idx)[self.zcl_idx]

        train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        BATCH_SIZE = 128
        SHUFFLE_BUFFER_SIZE = 1000
        train_dataset = train_dataset.shuffle(
            SHUFFLE_BUFFER_SIZE,
            reshuffle_each_iteration=True).batch(BATCH_SIZE)
        self.zcl_dataset = train_dataset
        self.zcl_x_test = x_test
        self.zcl_y_test = y_test

        self.zcl_cce = tf.keras.losses.SparseCategoricalCrossentropy()
        self.zcl_optimizer = tf.train.AdamOptimizer(
            learning_rate=LEARNING_RATE)
        self.zcl_global_step = tf.Variable(0)

    def __loss(self, model, x, y):
        y_ = model(x)
        return self.zcl_cce(y_true=y, y_pred=y_)

    def __grad(self, model, inputs, targets):
        with tf.GradientTape() as tape:
            loss_value = self.__loss(model, inputs, targets)
        return loss_value, tape.gradient(loss_value, model.trainable_variables)

    def __clip_gradients(self, grads, min_v, max_v):
        results = [tf.clip_by_value(t, min_v, max_v).numpy() for t in grads]
        return results

    def predict(self, data_instances):
        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        predict_wx = self.compute_wx(data_instances, self.model_weights.coef_,
                                     self.model_weights.intercept_)

        pred_table = self.classify(predict_wx,
                                   self.model_param.predict_param.threshold)

        predict_result = data_instances.mapValues(lambda x: x.label)
        predict_result = pred_table.join(
            predict_result,
            lambda x, y: [y, x[1], x[0], {
                "1": x[0],
                "0": 1 - x[0]
            }])
        return predict_result
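__init_model above shards the Fashion-MNIST training set evenly across the participating parties with np.split; the last party absorbs any remainder. A self-contained sketch of that arithmetic with made-up sizes (the party count and index here are hypothetical):

import numpy as np

num_party, idx = 3, 1                 # hypothetical party count and this party's index
x_train = np.arange(10)               # stand-in for the real training set
avg_length = int(len(x_train) / num_party)                  # -> 3
split_idx = [i * avg_length for i in range(1, num_party)]   # -> [3, 6]
parts = np.split(x_train, split_idx)  # [0 1 2], [3 4 5], [6 7 8 9]
print(parts[idx])                     # this party's shard: [3 4 5]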
Example #3
class HomoLRHost(HomoLRBase):
    def __init__(self):
        super(HomoLRHost, self).__init__()
        self.gradient_operator = None
        self.loss_history = []
        self.is_converged = False
        self.role = consts.HOST
        self.aggregator = aggregator.Host()
        self.model_weights = None
        self.cipher = paillier_cipher.Host()

        self.zcl_encrypt_operator = PaillierEncrypt()

    def _init_model(self, params):
        super()._init_model(params)
        self.cipher.register_paillier_cipher(self.transfer_variable)
        if params.encrypt_param.method in [consts.PAILLIER]:
            self.use_encrypt = True
            self.gradient_operator = TaylorLogisticGradient()
            self.re_encrypt_batches = params.re_encrypt_batches
        else:
            self.use_encrypt = False
            self.gradient_operator = LogisticGradient()

    def fit(self, data_instances, validate_data=None):
        LOGGER.debug("Start data count: {}".format(data_instances.count()))

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        self.model_weights = self._init_model_variables(data_instances)
        w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
        self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)

        LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
            re_encrypt_times = total_batch_num // self.re_encrypt_batches + 1
            LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
                re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        total_data_num = data_instances.count()
        LOGGER.debug("Current data count: {}".format(total_data_num))

        model_weights = self.model_weights
        degree = 0

        self.__synchronize_encryption()
        self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get(idx=0, suffix=('train',))
        LOGGER.debug("party num:" + str(self.zcl_num_party))
        self.__init_model()

        self.train_loss_results = []
        self.train_accuracy_results = []
        self.test_loss_results = []
        self.test_accuracy_results = []

        for iter_num in range(self.max_iter):
            # mini-batch
            LOGGER.debug("In iter: {}".format(iter_num))
            # batch_data_generator = self.mini_batch_obj.mini_batch_data_generator()
            batch_num = 0
            total_loss = 0
            epoch_train_loss_avg = tfe.metrics.Mean()
            epoch_train_accuracy = tfe.metrics.Accuracy()

            for train_x, train_y in self.zcl_dataset:
                LOGGER.info("Staring batch {}".format(batch_num))
                start_t = time.time()
                loss_value, grads = self.__grad(self.zcl_model, train_x, train_y)
                loss_value = loss_value.numpy()
                grads = [x.numpy() for x in grads]
                LOGGER.info("Start encrypting")
                loss_value = batch_encryption.encrypt(self.zcl_encrypt_operator.get_public_key(), loss_value)
                grads = [batch_encryption.encrypt_matrix(self.zcl_encrypt_operator.get_public_key(), x) for x in grads]
                LOGGER.info("Finish encrypting")
                grads = Gradients(grads)
                self.transfer_variable.host_grad.remote(obj=grads.for_remote(), role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Sent grads")
                self.transfer_variable.host_loss.remote(obj=loss_value, role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Sent loss")

                sum_grads = self.transfer_variable.aggregated_grad.get(idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got grads")
                sum_loss = self.transfer_variable.aggregated_loss.get(idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got loss")

                sum_loss = batch_encryption.decrypt(self.zcl_encrypt_operator.get_privacy_key(), sum_loss)
                sum_grads = [
                    batch_encryption.decrypt_matrix(self.zcl_encrypt_operator.get_privacy_key(), x).astype(np.float32)
                    for x
                    in sum_grads.unboxed]
                LOGGER.info("Finish decrypting")

                # sum_grads = np.array(sum_grads) / self.zcl_num_party

                self.zcl_optimizer.apply_gradients(zip(sum_grads, self.zcl_model.trainable_variables),
                                                   self.zcl_global_step)

                elapsed_time = time.time() - start_t
                # epoch_train_loss_avg(loss_value)
                # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32),
                #                      train_y)
                self.train_loss_results.append(sum_loss)
                train_accuracy_v = accuracy_score(train_y,
                                                  tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32))
                self.train_accuracy_results.append(train_accuracy_v)
                test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test, self.zcl_y_test)
                self.test_loss_results.append(test_loss_v)
                test_accuracy_v = accuracy_score(self.zcl_y_test,
                                                 tf.argmax(self.zcl_model(self.zcl_x_test), axis=1,
                                                           output_type=tf.int32))
                self.test_accuracy_results.append(test_accuracy_v)

                LOGGER.info(
                    "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, "
                    "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format(
                        iter_num,
                        batch_num,
                        sum_loss,
                        train_accuracy_v,
                        test_loss_v,
                        test_accuracy_v,
                        elapsed_time)
                )

                batch_num += 1

                if batch_num >= self.zcl_early_stop_batch:
                    return

            self.n_iter_ = iter_num

    def __synchronize_encryption(self, mode='train'):
        """
        Communicate with hosts. Specify whether use encryption or not and transfer the public keys.
        """
        pub_key = self.transfer_variable.paillier_pubkey.get(idx=0, suffix=(mode,))
        LOGGER.debug("Received pubkey")
        self.zcl_encrypt_operator.set_public_key(pub_key)
        pri_key = self.transfer_variable.paillier_prikey.get(idx=0, suffix=(mode,))
        LOGGER.debug("Received prikey")
        self.zcl_encrypt_operator.set_privacy_key(pri_key)

    def __init_model(self):
        # self.zcl_model = keras.Sequential([
        #     keras.layers.Flatten(input_shape=(28, 28)),
        #     keras.layers.Dense(128, activation=tf.nn.relu),
        #     keras.layers.Dense(10, activation=tf.nn.softmax)
        # ])
        #
        # LOGGER.info("Initialed model")
        with open(MODEL_JSON_DIR, 'r') as json_file:
            loaded_model_json = json_file.read()
        loaded_model = keras.models.model_from_json(loaded_model_json)
        loaded_model.load_weights(MODEL_WEIGHT_DIR)
        self.zcl_model = loaded_model
        LOGGER.info("Initialed model")

        # The data, split between train and test sets:
        (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train /= 255.0
        x_test /= 255.0
        y_train = y_train.squeeze().astype(np.int32)
        y_test = y_test.squeeze().astype(np.int32)

        avg_length = int(len(x_train) / self.zcl_num_party)
        split_idx = [_ * avg_length for _ in range(1, self.zcl_num_party)]
        x_train = np.split(x_train, split_idx)[self.zcl_idx]
        y_train = np.split(y_train, split_idx)[self.zcl_idx]

        train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        BATCH_SIZE = 128
        SHUFFLE_BUFFER_SIZE = 1000
        train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True).batch(BATCH_SIZE)
        self.zcl_dataset = train_dataset
        self.zcl_x_test = x_test
        self.zcl_y_test = y_test

        self.zcl_cce = tf.keras.losses.SparseCategoricalCrossentropy()
        self.zcl_optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        self.zcl_global_step = tf.Variable(0)

    def __loss(self, model, x, y):
        y_ = model(x)
        return self.zcl_cce(y_true=y, y_pred=y_)

    def __grad(self, model, inputs, targets):
        with tf.GradientTape() as tape:
            loss_value = self.__loss(model, inputs, targets)
        return loss_value, tape.gradient(loss_value, model.trainable_variables)

    def __clip_gradients(self, grads, min_v, max_v):
        results = [tf.clip_by_value(t, min_v, max_v).numpy() for t in grads]
        return results

    def predict(self, data_instances):

        LOGGER.info(f'Start predict task')
        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        suffix = ('predict',)
        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=suffix)
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        if self.use_encrypt:
            final_model = self.transfer_variable.aggregated_model.get(idx=0, suffix=suffix)
            model_weights = LogisticRegressionWeights(final_model.unboxed, self.fit_intercept)
            wx = self.compute_wx(data_instances, model_weights.coef_, model_weights.intercept_)
            self.transfer_variable.predict_wx.remote(wx, consts.ARBITER, 0, suffix=suffix)
            predict_result = self.transfer_variable.predict_result.get(idx=0, suffix=suffix)
            predict_result = predict_result.join(data_instances,
                                                 lambda p, d: [d.label, p, None, {"0": None, "1": None}])

        else:
            predict_wx = self.compute_wx(data_instances, self.model_weights.coef_, self.model_weights.intercept_)
            pred_table = self.classify(predict_wx, self.model_param.predict_param.threshold)
            predict_result = data_instances.mapValues(lambda x: x.label)
            predict_result = pred_table.join(predict_result, lambda x, y: [y, x[1], x[0],
                                                                           {"1": x[0], "0": 1 - x[0]}])
        return predict_result

    def _get_param(self):
        header = self.header

        weight_dict = {}
        intercept = 0
        if not self.use_encrypt:
            lr_vars = self.model_weights.coef_
            for idx, header_name in enumerate(header):
                coef_i = lr_vars[idx]
                weight_dict[header_name] = coef_i
            intercept = self.model_weights.intercept_

        param_protobuf_obj = lr_model_param_pb2.LRModelParam(iters=self.n_iter_,
                                                             loss_history=self.loss_history,
                                                             is_converged=self.is_converged,
                                                             weight=weight_dict,
                                                             intercept=intercept,
                                                             header=header)
        from google.protobuf import json_format
        json_result = json_format.MessageToJson(param_protobuf_obj)
        LOGGER.debug("json_result: {}".format(json_result))
        return param_protobuf_obj
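The re-encryption schedule in fit above is total_batch_num // re_encrypt_batches + 1; a tiny worked example with hypothetical figures:

# Hypothetical batch counts, only to show the schedule arithmetic from fit:
total_batch_num, re_encrypt_batches = 10, 3
re_encrypt_times = total_batch_num // re_encrypt_batches + 1   # -> 4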
class HeteroSSHEGuestBase(HeteroSSHEBase, ABC):
    def __init__(self):
        super().__init__()
        self.role = consts.GUEST
        self.local_party = get_parties().local_party
        self.other_party = get_parties().roles_to_parties(["host"])[0]
        self.parties = [self.local_party] + [self.other_party]
        self.encrypted_error = None
        self.encrypted_wx = None
        self.z_square = None
        self.wx_self = None
        self.wx_remote = None

    def _init_model(self, params):
        super()._init_model(params)
        # self.batch_generator = batch_generator.Guest()
        # self.batch_generator.register_batch_generator(BatchGeneratorTransferVariable(), has_arbiter=False)

    def _transfer_q_field(self):
        q_field = self.cipher.public_key.n
        self.transfer_variable.q_field.remote(q_field, role=consts.HOST, suffix=("q_field",))

        return q_field

    def _cal_z(self, weights, features, suffix, cipher):
        if not self.reveal_every_iter:
            LOGGER.info(f"[forward]: Calculate z in share...")
            w_self, w_remote = weights
            z = self._cal_z_in_share(w_self, w_remote, features, suffix, cipher)
        else:
            LOGGER.info(f"[forward]: Calculate z directly...")
            w = weights.unboxed
            z = features.dot_local(w)

        remote_z = self.secure_matrix_obj.share_encrypted_matrix(suffix=suffix,
                                                                 is_remote=False,
                                                                 cipher=None,
                                                                 z=None)[0]

        self.wx_self = z
        self.wx_remote = remote_z

    def _cal_z_in_share(self, w_self, w_remote, features, suffix, cipher):
        z1 = features.dot_local(w_self)

        za_suffix = ("za",) + suffix

        za_share = self.secure_matrix_obj.secure_matrix_mul(w_remote,
                                                            tensor_name=".".join(za_suffix),
                                                            cipher=cipher,
                                                            suffix=za_suffix)
        zb_suffix = ("zb",) + suffix
        zb_share = self.secure_matrix_obj.secure_matrix_mul(features,
                                                            tensor_name=".".join(zb_suffix),
                                                            cipher=None,
                                                            suffix=zb_suffix)

        z = z1 + za_share + zb_share
        return z

    def backward(self, error, features, suffix, cipher):
        LOGGER.info(f"[backward]: Calculate gradient...")
        batch_num = self.batch_num[int(suffix[1])]
        error_1_n = error * (1 / batch_num)

        ga2_suffix = ("ga2",) + suffix
        ga2_2 = self.secure_matrix_obj.secure_matrix_mul(error_1_n,
                                                         tensor_name=".".join(ga2_suffix),
                                                         cipher=cipher,
                                                         suffix=ga2_suffix,
                                                         is_fixedpoint_table=False)

        # LOGGER.debug(f"ga2_2: {ga2_2}")

        encrypt_g = self.encrypted_error.dot(features) * (1 / batch_num)

        # LOGGER.debug(f"encrypt_g: {encrypt_g}")

        tensor_name = ".".join(("encrypt_g",) + suffix)
        gb2 = SecureMatrix.from_source(tensor_name,
                                       encrypt_g,
                                       self.cipher,
                                       self.fixedpoint_encoder.n,
                                       self.fixedpoint_encoder)

        # LOGGER.debug(f"gb2: {gb2}")

        return gb2, ga2_2

    def share_model(self, w, suffix):
        source = [w, self.other_party]
        wb, wa = (
            fixedpoint_numpy.FixedPointTensor.from_source(f"wb_{suffix}", source[0],
                                                          encoder=self.fixedpoint_encoder,
                                                          q_field=self.q_field),
            fixedpoint_numpy.FixedPointTensor.from_source(f"wa_{suffix}", source[1],
                                                          encoder=self.fixedpoint_encoder,
                                                          q_field=self.q_field),
        )
        return wb, wa

    def reveal_models(self, w_self, w_remote, suffix=None):
        if suffix is None:
            suffix = self.n_iter_

        if self.model_param.reveal_strategy == "respectively":

            new_w = w_self.get(tensor_name=f"wb_{suffix}",
                               broadcast=False)
            w_remote.broadcast_reconstruct_share(tensor_name=f"wa_{suffix}")

        elif self.model_param.reveal_strategy == "encrypted_reveal_in_host":

            new_w = w_self.get(tensor_name=f"wb_{suffix}",
                               broadcast=False)
            encrypted_w_remote = self.cipher.recursive_encrypt(self.fixedpoint_encoder.decode(w_remote.value))
            encrypted_w_remote_tensor = fixedpoint_numpy.PaillierFixedPointTensor(value=encrypted_w_remote)
            encrypted_w_remote_tensor.broadcast_reconstruct_share(tensor_name=f"wa_{suffix}")

        else:
            raise NotImplementedError(f"reveal strategy: {self.model_param.reveal_strategy} has not been implemented.")
        return new_w

    def _reveal_every_iter_weights_check(self, last_w, new_w, suffix):
        square_sum = np.sum((last_w - new_w) ** 2)
        host_sums = self.converge_transfer_variable.square_sum.get(suffix=suffix)
        for hs in host_sums:
            square_sum += hs
        weight_diff = np.sqrt(square_sum)
        is_converge = False
        if weight_diff < self.model_param.tol:
            is_converge = True
        LOGGER.info(f"n_iter: {self.n_iter_}, weight_diff: {weight_diff}")
        self.converge_transfer_variable.converge_info.remote(is_converge, role=consts.HOST, suffix=suffix)
        return is_converge

    def check_converge_by_loss(self, loss, suffix):
        self.is_converged = self.converge_func.is_converge(loss)
        self.transfer_variable.is_converged.remote(self.is_converged, suffix=suffix)

        return self.is_converged

    def prepare_fit(self, data_instances, validate_data):
        # self.transfer_variable = SSHEModelTransferVariable()
        self.batch_generator = batch_generator.Guest()
        self.batch_generator.register_batch_generator(BatchGeneratorTransferVariable(), has_arbiter=False)
        self.header = copy.deepcopy(data_instances.schema.get("header", []))
        self._abnormal_detection(data_instances)
        self.check_abnormal_values(data_instances)
        self.check_abnormal_values(validate_data)

    def get_single_model_param(self, model_weights=None, header=None):
        result = super().get_single_model_param(model_weights, header)
        result['weight'] = self.get_single_model_weight_dict(model_weights, header)
        if not self.is_respectively_reveal:
            result["cipher"] = dict(public_key=dict(n=str(self.cipher.public_key.n)),
                                    private_key=dict(p=str(self.cipher.privacy_key.p),
                                                     q=str(self.cipher.privacy_key.q)))

        return result

    def load_single_model(self, single_model_obj):
        LOGGER.info("start to load single model")

        self.load_single_model_weight(single_model_obj)
        self.n_iter_ = single_model_obj.iters

        if not self.is_respectively_reveal:
            cipher_info = single_model_obj.cipher
            self.cipher = PaillierEncrypt()
            public_key = PaillierPublicKey(int(cipher_info.public_key.n))
            privacy_key = PaillierPrivateKey(public_key, int(cipher_info.private_key.p), int(cipher_info.private_key.q))
            self.cipher.set_public_key(public_key=public_key)
            self.cipher.set_privacy_key(privacy_key=privacy_key)

        return self
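A round-trip sketch of the cipher (de)serialization in get_single_model_param and load_single_model above: the key material is stored as decimal strings of n, p, and q and rebuilt from them. The import paths and the generate_key() helper are assumptions about FATE's layout; only the attributes the methods above already use are relied on.

from federatedml.secureprotol import PaillierEncrypt
from federatedml.secureprotol.fate_paillier import PaillierPublicKey, PaillierPrivateKey

cipher = PaillierEncrypt()
cipher.generate_key()  # assumed key-generation helper

# Serialize, mirroring get_single_model_param:
stored = dict(public_key=dict(n=str(cipher.public_key.n)),
              private_key=dict(p=str(cipher.privacy_key.p),
                               q=str(cipher.privacy_key.q)))

# Rebuild, mirroring load_single_model:
restored = PaillierEncrypt()
public_key = PaillierPublicKey(int(stored["public_key"]["n"]))
privacy_key = PaillierPrivateKey(public_key,
                                 int(stored["private_key"]["p"]),
                                 int(stored["private_key"]["q"]))
restored.set_public_key(public_key=public_key)
restored.set_privacy_key(privacy_key=privacy_key)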