Example #1
0
    def __decrypt_gradients(self, enc_grads, is_host, n_iter, batch_index):
        """Mask encrypted gradients, send them to the arbiter, unmask the reply.

        The gradients are masked with ``add_random_mask`` and sent to
        ``consts.ARBITER`` under the ``*_enc_gradient`` transfer variable.
        The value fetched under the matching ``*_dec_gradient`` variable is
        then unmasked with ``remove_random_mask`` and returned. The reply is
        presumably the decrypted gradients; the decryption itself is not
        visible in this method.

        Args:
            enc_grads: Encrypted gradients to be masked and sent.
            is_host: Truthy selects the ``host_*`` transfer variables,
                otherwise the ``guest_*`` ones are used.
            n_iter: Iteration number folded into the transfer ids.
            batch_index: Batch index folded into the transfer ids.

        Returns:
            The fetched reply with the random mask removed.
        """
        variables = self.transfer_variable
        if is_host:
            outgoing = variables.host_enc_gradient
            incoming = variables.host_dec_gradient
        else:
            outgoing = variables.guest_enc_gradient
            incoming = variables.guest_dec_gradient

        # Both tags are derived up front, before anything is masked or sent.
        send_tag = variables.generate_transferid(outgoing, n_iter, batch_index)
        fetch_tag = variables.generate_transferid(incoming, n_iter, batch_index)

        masked_payload, mask = add_random_mask(enc_grads)
        self.federation_client.remote(
            masked_payload,
            name=outgoing.name,
            tag=send_tag,
            role=consts.ARBITER,
            idx=0,
        )

        masked_reply = self.federation_client.get(
            name=incoming.name,
            tag=fetch_tag,
            idx=0,
        )
        return remove_random_mask(masked_reply, mask)
Example #2
0
    def __test_matrix(self, matrix):
        """Check that masking then unmasking an encrypted matrix is lossless.

        A fresh key pair is generated, ``matrix`` is encrypted, a random mask
        is added and removed again, and the decrypted result is compared with
        the input through ``assert_matrix``.

        Args:
            matrix: Plain matrix to push through the round trip.
        """
        cipher = PaillierEncrypt()
        cipher.generate_key()
        public_key = cipher.get_public_key()
        private_key = cipher.get_privacy_key()

        # encrypt -> mask -> unmask -> decrypt
        masked, mask = add_random_mask(encrypt_matrix(public_key, matrix))
        unmasked = remove_random_mask(masked, mask)
        recovered = decrypt_matrix(private_key, unmasked)

        assert_matrix(matrix, recovered)
Example #3
0
    def __test_scalar(self, value):
        """Check that masking then unmasking an encrypted scalar is lossless.

        A fresh key pair is generated, ``value`` is encrypted with the public
        key, a random mask is added and removed again, and the decrypted
        result is compared with the input via ``self.assertEqual``.

        Args:
            value: Plain scalar to push through the round trip.
        """
        cipher = PaillierEncrypt()
        cipher.generate_key()
        public_key = cipher.get_public_key()
        private_key = cipher.get_privacy_key()

        # encrypt -> mask -> unmask -> decrypt
        masked, mask = add_random_mask(public_key.encrypt(value))
        unmasked = remove_random_mask(masked, mask)
        recovered = private_key.decrypt(unmasked)

        # NOTE(review): the labels say "matrix" although this is the scalar
        # path; they are left untouched because they are runtime output.
        print("original matrix", value)
        print("cleared_matrix", recovered)
        self.assertEqual(value, recovered)
Example #4
0
    def fit(self, guest_data):
        """Run the guest side of the iterative training exchange with the host.

        Each pass of the loop sends the guest components to the host,
        receives the host components, exchanges masked gradients and the
        masked loss for decryption, hands the unmasked guest gradients to
        ``self.guest_model.receive_gradients`` and finally sends a stop flag
        to the host. The loop ends once
        ``self.converge_func.is_converge(loss)`` is truthy or ``self.n_iter_``
        reaches ``self.max_iter``.

        Args:
            guest_data: Guest-side input, handed to ``self.prepare_data``.

        Returns:
            None. Progress is reported through ``LOGGER``.
        """

        def iteration_tag(variable):
            # Transfer id of ``variable`` for the iteration currently running.
            return self.transfer_variable.generate_transferid(
                variable, self.n_iter_)

        def send_to_host(payload, variable):
            # Ship ``payload`` to the host under ``variable``'s name.
            self._do_remote(payload,
                            name=variable.name,
                            tag=iteration_tag(variable),
                            role=consts.HOST,
                            idx=-1)

        def receive_first(variable):
            # Fetch what was posted under ``variable``; keep the first entry.
            return self._do_get(name=variable.name,
                                tag=iteration_tag(variable),
                                idx=-1)[0]

        LOGGER.info("@ start guest fit")
        self.prepare_encryption_key_pair()
        prepared = self.prepare_data(guest_data)
        guest_x, overlap_indexes, non_overlap_indexes, guest_y = prepared

        # ``!s`` keeps the formatting on ``str()``, matching plain
        # string concatenation for values whose type is not visible here.
        LOGGER.debug(f"guest_x: {guest_x.shape!s}")
        LOGGER.debug(f"guest_y: {guest_y.shape!s}")
        LOGGER.debug(f"overlap_indexes: {len(overlap_indexes)!s}")
        LOGGER.debug(f"non_overlap_indexes: {len(non_overlap_indexes)!s}")
        LOGGER.debug(f"converge eps: {self.converge_func.eps!s}")

        self.guest_model.set_batch(
            guest_x, guest_y, non_overlap_indexes, overlap_indexes)
        self.guest_model.set_public_key(self.public_key)
        self.guest_model.set_host_public_key(self.host_public_key)
        self.guest_model.set_private_key(self.private_key)

        start_time = time.time()
        while self.n_iter_ < self.max_iter:

            # Stage 1: compute and encrypt (with the guest public key) the
            #          components the host needs to calculate gradients.
            LOGGER.debug("@ Stage 1: ")
            guest_comp = self.guest_model.send_components()
            LOGGER.debug(
                f"send enc guest_comp: {create_shape_msg(guest_comp)!s}")
            send_to_host(guest_comp,
                         self.transfer_variable.guest_component_list)

            # Stage 2: receive host components (encrypted by the host public
            #          key), calculate guest gradients and loss in encrypted
            #          form, and send them to the host for decryption.
            LOGGER.debug("@ Stage 2: ")
            host_comp = receive_first(
                self.transfer_variable.host_component_list)
            LOGGER.debug(
                f"receive enc host_comp: {create_shape_msg(host_comp)!s}")
            self.guest_model.receive_components(host_comp)

            self._precompute()

            # Guest gradients in encrypted form (encrypted by the host key).
            enc_gradients = self.guest_model.send_gradients()
            LOGGER.debug("compute encrypt_guest_gradients: "
                         f"{create_shape_msg(enc_gradients)!s}")
            enc_loss = self.guest_model.send_loss()

            # Mask gradients and loss before handing them to the host.
            masked_gradients, gradients_masks = (
                add_random_mask_for_list_of_values(enc_gradients))
            masked_loss, loss_mask = add_random_mask(enc_loss)

            LOGGER.debug("send masked_enc_guest_gradients: "
                         f"{create_shape_msg(masked_gradients)!s}")
            send_to_host(masked_gradients,
                         self.transfer_variable.masked_enc_guest_gradients)
            send_to_host(masked_loss,
                         self.transfer_variable.masked_enc_loss)

            # Stage 3: receive the masked encrypted host gradients, decrypt
            #          them and send the result back to the host.
            LOGGER.debug("@ Stage 3: ")
            host_gradients_in = receive_first(
                self.transfer_variable.masked_enc_host_gradients)
            host_gradients_out = self.__decrypt_gradients(host_gradients_in)

            LOGGER.debug("send masked_dec_host_gradients: "
                         f"{create_shape_msg(host_gradients_out)!s}")
            send_to_host(host_gradients_out,
                         self.transfer_variable.masked_dec_host_gradients)

            # Stage 4: receive masked but decrypted guest gradients and loss
            #          from the host, remove the masks, and update the guest
            #          model parameters using these gradients.
            LOGGER.debug("@ Stage 4: ")
            masked_dec_gradients = receive_first(
                self.transfer_variable.masked_dec_guest_gradients)
            LOGGER.debug("receive masked_dec_guest_gradients: "
                         f"{create_shape_msg(masked_dec_gradients)!s}")

            clear_gradients = remove_random_mask_from_list_of_values(
                masked_dec_gradients, gradients_masks)
            self.guest_model.receive_gradients(clear_gradients)

            masked_dec_loss = receive_first(
                self.transfer_variable.masked_dec_loss)
            LOGGER.debug(f"receive masked_dec_loss: {masked_dec_loss!s}")

            loss = remove_random_mask(masked_dec_loss, loss_mask)

            # Stage 5: decide from the loss whether training is finished and
            #          send the stop signal to the host.
            LOGGER.debug("@ Stage 5: ")
            is_stop = bool(self.converge_func.is_converge(loss))
            send_to_host(
                is_stop,
                self.transfer_variable.is_decentralized_enc_ftl_stopped)

            LOGGER.info(f"@ time: {time.time()!s}, ep:{self.n_iter_!s}, "
                        f"loss:{loss!s}")
            LOGGER.info(f"@ converged: {is_stop!s}")
            self.n_iter_ += 1
            if is_stop:
                break

        elapsed = time.time() - start_time
        LOGGER.info(f"@ running time: {elapsed!s}")