Beispiel #1
0
class TestHeteroFTLGuest(HeteroPlainFTLGuest):
    """Plain FTL guest that talks to an in-process host instead of a remote one.

    The constructor builds its own ``PlainFTLHostModel`` (backed by a
    ``MockAutoencoder``), ``_do_remote`` only prints what would be sent, and
    ``_do_get`` answers requests from canned values or from that local host.
    """

    def __init__(self, guest, model_param, transfer_variable):
        super().__init__(guest, model_param, transfer_variable)

        # Host-side feature matrix: 4 rows x 5 columns.
        host_features = np.array([
            [4, 2, 3, 1, 2],
            [6, 5, 1, 4, 5],
            [7, 4, 1, 9, 10],
            [6, 5, 1, 4, 5],
        ])
        overlapping_rows = [1, 2]
        n_features = host_features.shape[1]

        # Hand-picked autoencoder parameters: all-ones weights, zero bias.
        weights = np.ones((5, n_features))
        bias = np.zeros(n_features)

        mock_encoder = MockAutoencoder(1)
        mock_encoder.build(n_features, weights, bias)

        self.host = PlainFTLHostModel(mock_encoder, self.model_param)
        self.host.set_batch(host_features, overlapping_rows)

    def _do_remote(self, value=None, name=None, tag=None, role=None, idx=None):
        """Print the outgoing message instead of transmitting it."""
        print("@_do_remote", value, name, tag, role, idx)

    def _do_get(self, name=None, tag=None, idx=None):
        """Serve a "remote" value locally, selected by the transfer tag.

        Returns a canned list of host sample indexes, the local host model's
        components, or ``None`` for any other tag.
        """
        print("@_do_get", name, tag, idx)
        if tag == "HeteroFTLTransferVariable.host_sample_indexes.0":
            return [np.array([1, 2, 4, 5])]
        if tag == "HeteroFTLTransferVariable.host_component_list.0.0":
            return self.host.send_components()
        return None
Beispiel #2
0
    def __init__(self, guest, model_param, transfer_variable):
        """Initialise the guest and attach an in-process plain host model.

        The host is a ``PlainFTLHostModel`` backed by a ``MockAutoencoder``
        built with all-ones weights and zero bias. It is fed a fixed 4x5
        feature matrix, with rows 1 and 2 passed as the overlap indexes.
        """
        super(TestHeteroFTLGuest, self).__init__(guest, model_param,
                                                 transfer_variable)

        host_features = np.array([
            [4, 2, 3, 1, 2],
            [6, 5, 1, 4, 5],
            [7, 4, 1, 9, 10],
            [6, 5, 1, 4, 5],
        ])
        overlapping_rows = [1, 2]
        n_features = host_features.shape[1]

        weights = np.ones((5, n_features))
        bias = np.zeros(n_features)

        mock_encoder = MockAutoencoder(1)
        mock_encoder.build(n_features, weights, bias)

        self.host = PlainFTLHostModel(mock_encoder, self.model_param)
        self.host.set_batch(host_features, overlapping_rows)
    def create(cls, ftl_model_param: FTLModelParam,
               transfer_variable: HeteroFTLTransferVariable, ftl_local_model):
        """Build the FTL host that matches ``ftl_model_param``.

        Args:
            ftl_model_param: FTL configuration. ``is_encrypt`` selects plain
                vs. encrypted. When encrypted, ``enc_ftl`` selects the variant:
                "dct_enc_ftl", "dct_enc_ftl2", "enc_ftl2" or "enc_ot". Any
                other value falls back to the standard encrypted host.
            transfer_variable: passed through to the host constructor.
            ftl_local_model: local model wrapped by the host model.

        Returns:
            The constructed host instance.
        """
        model_kwargs = {"local_model": ftl_local_model,
                        "model_param": ftl_model_param}

        # Unencrypted training always uses the plain host.
        if not ftl_model_param.is_encrypt:
            LOGGER.debug("@ create plain ftl_host")
            host_model = PlainFTLHostModel(**model_kwargs)
            return HeteroPlainFTLHost(host_model, ftl_model_param,
                                      transfer_variable)

        variant = ftl_model_param.enc_ftl

        if variant == "dct_enc_ftl":
            # Decentralized encrypted host.
            LOGGER.debug("@ create decentralized encrypted ftl_host")
            host_model = EncryptedFTLHostModel(**model_kwargs)
            return HeteroDecentralizedEncryptFTLHost(
                host_model, ftl_model_param, transfer_variable)

        if variant == "dct_enc_ftl2":
            # Decentralized encrypted host, "faster" model variant.
            LOGGER.debug("@ create decentralized encrypted faster ftl_host")
            host_model = FasterEncryptedFTLHostModel(**model_kwargs)
            return FasterHeteroDecentralizedEncryptFTLHost(
                host_model, ftl_model_param, transfer_variable)

        if variant == "enc_ftl2":
            # Encrypted host, "faster" model variant.
            LOGGER.debug("@ create encrypted faster ftl_host")
            host_model = FasterEncryptedFTLHostModel(**model_kwargs)
            return FasterHeteroEncryptFTLHost(
                host_model, ftl_model_param, transfer_variable)

        if variant == "enc_ot":
            # OT-based encrypted host.
            # NOTE(review): the class name below ends in "FTLhost" while every
            # other host class ends in "FTLHost"; confirm it exists as spelled.
            LOGGER.debug("@ create encrypt OT host")
            host_model = OTEncryptedFTLHostModel(**model_kwargs)
            return HeteroOTEncryptFTLhost(
                host_model, ftl_model_param, transfer_variable)

        # Fallback for any other enc_ftl value: the standard encrypted host.
        LOGGER.debug("@ create encrypted ftl_host")
        host_model = EncryptedFTLHostModel(**model_kwargs)
        return HeteroEncryptFTLHost(host_model, ftl_model_param,
                                    transfer_variable)
Beispiel #4
0
 def _do_initialize_model(self, ftl_model_param, ftl_local_model_param,
                          ftl_data_param):
     """Build the local model and wrap it in the matching FTL host.

     Stores the local model on ``self.ftl_local_model`` and the host on
     ``self.model``. ``ftl_model_param.is_encrypt`` chooses the host:
     ``EncryptedFTLHostModel`` + ``HeteroEncryptFTLHost`` when encrypted,
     otherwise ``PlainFTLHostModel`` + ``HeteroPlainFTLHost``.
     """
     self.ftl_local_model = self._create_local_model(ftl_local_model_param,
                                                     ftl_data_param)

     if not ftl_model_param.is_encrypt:
         # Unencrypted training: plain host.
         LOGGER.debug("@ create plain ftl_host")
         plain_host_model = PlainFTLHostModel(
             local_model=self.ftl_local_model, model_param=ftl_model_param)
         self.model = HeteroPlainFTLHost(plain_host_model, ftl_model_param,
                                         self._get_transfer_variable())
         return

     LOGGER.debug("@ create encrypt ftl_host")
     encrypted_host_model = EncryptedFTLHostModel(
         local_model=self.ftl_local_model, model_param=ftl_model_param)
     self.model = HeteroEncryptFTLHost(encrypted_host_model, ftl_model_param,
                                       self._get_transfer_variable())
def run_one_party_msg_exchange(autoencoderA,
                               autoencoderB,
                               U_A,
                               U_B,
                               y,
                               overlap_indexes,
                               non_overlap_indexes,
                               public_key=None,
                               private_key=None,
                               is_encrypted=False):
    """Create a guest/host FTL model pair and run one component exchange.

    Party A (guest) is built from ``autoencoderA`` and batched with ``U_A``
    and ``y``. Party B (host) is built from ``autoencoderB`` and batched with
    ``U_B``. Both share a ``FakeFTLModelParam(alpha=1)``. With
    ``is_encrypted`` the encrypted model classes are used with the given
    key pair, otherwise the plain ones. Each party then sends its components
    and receives the other's.

    Returns:
        Tuple ``(partyA, partyB)`` after the exchange.
    """
    model_param = FakeFTLModelParam(alpha=1)

    # Party A: the guest, which holds the labels.
    if is_encrypted:
        guest = EncryptedFTLGuestModel(autoencoderA,
                                       model_param,
                                       public_key=public_key,
                                       private_key=private_key)
    else:
        guest = PlainFTLGuestModel(autoencoderA, model_param)
    guest.set_batch(U_A, y, non_overlap_indexes, overlap_indexes)

    # Party B: the host.
    if is_encrypted:
        host = EncryptedFTLHostModel(autoencoderB,
                                     model_param,
                                     public_key=public_key,
                                     private_key=private_key)
    else:
        host = PlainFTLHostModel(autoencoderB, model_param)
    host.set_batch(U_B, overlap_indexes)

    # Both sides send before either receives. Unpacking into three names
    # requires each party to send exactly three components.
    a_beta1, a_beta2, a_mapping = guest.send_components()
    b_overlap, b_overlap_2, b_mapping = host.send_components()

    guest.receive_components([b_overlap, b_overlap_2, b_mapping])
    host.receive_components([a_beta1, a_beta2, a_mapping])

    return guest, host
    # NOTE(review): any statements indented below this return are unreachable.
    print(
        "################################ Build Federated Models ############################"
    )

    tf.reset_default_graph()

    autoencoder_A = Autoencoder(1)
    autoencoder_B = Autoencoder(2)

    autoencoder_A.build(X_A.shape[-1], 200, learning_rate=0.01)
    autoencoder_B.build(X_B.shape[-1], 200, learning_rate=0.01)

    # alpha = 100
    fake_model_param = FakeFTLModelParam()
    partyA = PlainFTLGuestModel(autoencoder_A, fake_model_param)
    partyB = PlainFTLHostModel(autoencoder_B, fake_model_param)

    federatedLearning = LocalPlainFederatedTransferLearning(partyA, partyB)

    print(
        "################################ Train Federated Models ############################"
    )
    start_time = time.time()
    epochs = 100
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        autoencoder_A.set_session(sess)
        autoencoder_B.set_session(sess)

        sess.run(init)
        losses = []