def create(cls, ftl_model_param: FTLModelParam,
           transfer_variable: HeteroFTLTransferVariable, ftl_local_model):
        """Instantiate the hetero FTL host selected by ``ftl_model_param``.

        ``ftl_model_param.is_encrypt`` chooses between the plain host and the
        encrypted family. For encrypted hosts, ``ftl_model_param.enc_ftl``
        picks the variant (host model class in parentheses):

        * ``"dct_enc_ftl"``  -> HeteroDecentralizedEncryptFTLHost
          (EncryptedFTLHostModel)
        * ``"dct_enc_ftl2"`` -> FasterHeteroDecentralizedEncryptFTLHost
          (FasterEncryptedFTLHostModel)
        * ``"enc_ftl2"``     -> FasterHeteroEncryptFTLHost
          (FasterEncryptedFTLHostModel)
        * ``"enc_ot"``       -> HeteroOTEncryptFTLhost
          (OTEncryptedFTLHostModel)
        * any other value   -> HeteroEncryptFTLHost (EncryptedFTLHostModel)

        When not encrypted, a HeteroPlainFTLHost wrapping a
        PlainFTLHostModel is built. Each path logs one debug line first.

        :param cls: unused by this body.
        :param ftl_model_param: FTL configuration; ``is_encrypt`` and
            ``enc_ftl`` are read here, and it is handed to both the host
            model and the host.
        :param transfer_variable: forwarded unchanged to the host constructor.
        :param ftl_local_model: local model wrapped by the host model.
        :return: the constructed host object.
        """
        if not ftl_model_param.is_encrypt:
            # plain (unencrypted) ftl host
            LOGGER.debug("@ create plain ftl_host")
            plain_model = PlainFTLHostModel(local_model=ftl_local_model,
                                            model_param=ftl_model_param)
            return HeteroPlainFTLHost(plain_model, ftl_model_param,
                                      transfer_variable)

        variant = ftl_model_param.enc_ftl

        if variant == "dct_enc_ftl":
            # decentralized encrypted ftl host
            LOGGER.debug("@ create decentralized encrypted ftl_host")
            enc_model = EncryptedFTLHostModel(local_model=ftl_local_model,
                                              model_param=ftl_model_param)
            return HeteroDecentralizedEncryptFTLHost(
                enc_model, ftl_model_param, transfer_variable)

        if variant == "dct_enc_ftl2":
            # decentralized encrypted "faster" ftl host
            LOGGER.debug(
                "@ create decentralized encrypted faster ftl_host")
            fast_model = FasterEncryptedFTLHostModel(
                local_model=ftl_local_model, model_param=ftl_model_param)
            return FasterHeteroDecentralizedEncryptFTLHost(
                fast_model, ftl_model_param, transfer_variable)

        if variant == "enc_ftl2":
            # encrypted "faster" ftl host
            LOGGER.debug("@ create encrypted faster ftl_host")
            fast_model = FasterEncryptedFTLHostModel(
                local_model=ftl_local_model, model_param=ftl_model_param)
            return FasterHeteroEncryptFTLHost(fast_model, ftl_model_param,
                                              transfer_variable)

        if variant == "enc_ot":
            # encrypted OT host
            # NOTE(review): "FTLhost" (lowercase h) differs from every sibling
            # host class name; confirm this matches the declared class,
            # otherwise this branch fails with NameError.
            LOGGER.debug("@ create encrypt OT host")
            ot_model = OTEncryptedFTLHostModel(
                local_model=ftl_local_model, model_param=ftl_model_param)
            return HeteroOTEncryptFTLhost(ot_model, ftl_model_param,
                                          transfer_variable)

        # fallback for any other enc_ftl value: standard encrypted ftl host
        LOGGER.debug("@ create encrypted ftl_host")
        enc_model = EncryptedFTLHostModel(local_model=ftl_local_model,
                                          model_param=ftl_model_param)
        return HeteroEncryptFTLHost(enc_model, ftl_model_param,
                                    transfer_variable)
Example #2
0
def run_one_party_msg_exchange(autoencoderA,
                               autoencoderB,
                               U_A,
                               U_B,
                               y,
                               overlap_indexes,
                               non_overlap_indexes,
                               public_key=None,
                               private_key=None):
    """Run one round of component exchange between a guest and a host model.

    Party A is a ``FasterEncryptedFTLGuestModel`` around ``autoencoderA``, fed
    ``U_A``, ``y``, ``non_overlap_indexes`` and ``overlap_indexes``. Party B is
    a ``FasterEncryptedFTLHostModel`` around ``autoencoderB``, fed ``U_B`` and
    ``overlap_indexes``. Both share one ``MockFTLModelParam(alpha=1)`` and
    receive the same ``public_key`` / ``private_key``.

    The parties then swap their components, followed by their precomputed
    components. No gradient exchange is performed.

    :return: tuple ``(partyA, partyB)`` in their post-exchange state.
    """
    mock_param = MockFTLModelParam(alpha=1)
    key_kwargs = {"public_key": public_key, "private_key": private_key}

    guest = FasterEncryptedFTLGuestModel(autoencoderA, mock_param, **key_kwargs)
    guest.set_batch(U_A, y, non_overlap_indexes, overlap_indexes)
    host = FasterEncryptedFTLHostModel(autoencoderB, mock_param, **key_kwargs)
    host.set_batch(U_B, overlap_indexes)

    # Round 1: both sides emit their components before either consumes any.
    # Unpacking enforces the expected arity of each side's payload.
    a_y_overlap_phi, a_mapping, a_phi, a_phi_2 = guest.send_components()
    b_uB_overlap, b_mapping = host.send_components()
    guest.receive_components([b_uB_overlap, b_mapping])
    host.receive_components([a_y_overlap_phi, a_mapping, a_phi, a_phi_2])

    # Round 2: same send-then-receive pattern for precomputed components.
    guest_precomputed = guest.send_precomputed_components()
    host_precomputed = host.send_precomputed_components()
    guest.receive_precomputed_components(host_precomputed)
    host.receive_precomputed_components(guest_precomputed)

    return guest, host
Example #3
0
    tf.reset_default_graph()

    autoencoder_A = Autoencoder(1)
    autoencoder_B = Autoencoder(2)

    autoencoder_A.build(X_A.shape[-1], 32, learning_rate=0.01)
    autoencoder_B.build(X_B.shape[-1], 32, learning_rate=0.01)

    paillierEncrypt = PaillierEncrypt()
    paillierEncrypt.generate_key()
    publickey = paillierEncrypt.get_public_key()
    privatekey = paillierEncrypt.get_privacy_key()

    fake_model_param = FakeFTLModelParam(alpha=100)
    partyA = FasterEncryptedFTLGuestModel(autoencoder_A, fake_model_param, public_key=publickey)
    partyB = FasterEncryptedFTLHostModel(autoencoder_B, fake_model_param, public_key=publickey)

    federatedLearning = LocalFasterEncryptedFederatedTransferLearning(partyA, partyB, privatekey)

    print("################################ Train Federated Models ############################")
    start_time = time.time()
    epochs = 10
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        autoencoder_A.set_session(sess)
        autoencoder_B.set_session(sess)

        sess.run(init)
        losses = []
        fscores = []
        aucs = []