def fit(self, host_data):
    """Run the host side of decentralized encrypted FTL training.

    Exchanges encrypted model components and randomly-masked gradients
    with the guest each iteration, so that neither party ever sees the
    other's plaintext gradients.

    :param host_data: raw host dataset; converted by ``prepare_data`` into
        features and the indexes of samples overlapping with the guest.
    """
    LOGGER.info("@ start host fit")
    self.prepare_encryption_key_pair()
    host_x, overlap_indexes = self.prepare_data(host_data)
    LOGGER.debug("host_x: " + str(host_x.shape))
    LOGGER.debug("overlap_indexes: " + str(len(overlap_indexes)))

    self.host_model.set_batch(host_x, overlap_indexes)
    self.host_model.set_public_key(self.public_key)
    self.host_model.set_guest_public_key(self.guest_public_key)
    self.host_model.set_private_key(self.private_key)

    start_time = time.time()
    while self.n_iter_ < self.max_iter:

        # Stage 1: compute and encrypt components (using host public key) required by guest to
        #          calculate gradients and loss.
        LOGGER.debug("@ Stage 1: ")
        host_comp = self.host_model.send_components()
        LOGGER.debug("send enc host_comp: " + create_shape_msg(host_comp))
        self._do_remote(host_comp,
                        name=self.transfer_variable.host_component_list.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_component_list, self.n_iter_),
                        role=consts.GUEST,
                        idx=-1)

        # Stage 2: receive guest components in encrypted form (encrypted by guest public key),
        #          and calculate host gradients in encrypted form (encrypted by guest public key),
        #          and send them to guest for decryption
        LOGGER.debug("@ Stage 2: ")
        guest_comp = self._do_get(name=self.transfer_variable.guest_component_list.name,
                                  tag=self.transfer_variable.generate_transferid(
                                      self.transfer_variable.guest_component_list, self.n_iter_),
                                  idx=-1)[0]
        LOGGER.debug("receive enc guest_comp: " + create_shape_msg(guest_comp))
        self.host_model.receive_components(guest_comp)

        self._precompute()

        # calculate host gradients in encrypted form (encrypted by guest public key)
        encrypt_host_gradients = self.host_model.send_gradients()
        # NOTE: log label fixed — these are the *host* gradients, not guest ones.
        LOGGER.debug("send encrypt_host_gradients: " + create_shape_msg(encrypt_host_gradients))

        # add random mask to encrypt_host_gradients and send them to guest for decryption;
        # the mask keeps the plaintext gradients hidden from the guest after it decrypts.
        masked_enc_host_gradients, gradients_masks = add_random_mask_for_list_of_values(
            encrypt_host_gradients)
        LOGGER.debug("send masked_enc_host_gradients: " + create_shape_msg(masked_enc_host_gradients))
        self._do_remote(masked_enc_host_gradients,
                        name=self.transfer_variable.masked_enc_host_gradients.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.masked_enc_host_gradients, self.n_iter_),
                        role=consts.GUEST,
                        idx=-1)

        # Stage 3: receive and then decrypt masked encrypted guest gradients and masked encrypted guest loss,
        #          and send them to guest
        LOGGER.debug("@ Stage 3: ")
        masked_enc_guest_gradients = self._do_get(
            name=self.transfer_variable.masked_enc_guest_gradients.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.masked_enc_guest_gradients, self.n_iter_),
            idx=-1)[0]
        masked_enc_guest_loss = self._do_get(
            name=self.transfer_variable.masked_enc_loss.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.masked_enc_loss, self.n_iter_),
            idx=-1)[0]

        masked_dec_guest_gradients = self.__decrypt_gradients(masked_enc_guest_gradients)
        masked_dec_guest_loss = self.__decrypt_loss(masked_enc_guest_loss)

        LOGGER.debug("send masked_dec_guest_gradients: " + create_shape_msg(masked_dec_guest_gradients))
        self._do_remote(masked_dec_guest_gradients,
                        name=self.transfer_variable.masked_dec_guest_gradients.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.masked_dec_guest_gradients, self.n_iter_),
                        role=consts.GUEST,
                        idx=-1)
        LOGGER.debug("send masked_dec_guest_loss: " + str(masked_dec_guest_loss))
        self._do_remote(masked_dec_guest_loss,
                        name=self.transfer_variable.masked_dec_loss.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.masked_dec_loss, self.n_iter_),
                        role=consts.GUEST,
                        idx=-1)

        # Stage 4: receive masked but decrypted host gradients from guest and remove mask,
        #          and update host model parameters using these gradients.
        LOGGER.debug("@ Stage 4: ")
        masked_dec_host_gradients = self._do_get(
            name=self.transfer_variable.masked_dec_host_gradients.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.masked_dec_host_gradients, self.n_iter_),
            idx=-1)[0]
        LOGGER.debug("receive masked_dec_host_gradients: " + create_shape_msg(masked_dec_host_gradients))

        cleared_dec_host_gradients = remove_random_mask_from_list_of_values(
            masked_dec_host_gradients, gradients_masks)

        # update host model parameters using these gradients.
        self.host_model.receive_gradients(cleared_dec_host_gradients)

        # Stage 5: determine whether training is terminated.
        LOGGER.debug("@ Stage 5: ")
        is_stop = self._do_get(
            name=self.transfer_variable.is_decentralized_enc_ftl_stopped.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.is_decentralized_enc_ftl_stopped, self.n_iter_),
            idx=-1)[0]

        LOGGER.info("@ time: " + str(time.time()) + ", ep: " + str(self.n_iter_)
                    + ", converged: " + str(is_stop))
        self.n_iter_ += 1
        if is_stop:
            break

    end_time = time.time()
    LOGGER.info("@ running time: " + str(end_time - start_time))
def fit(self, guest_data):
    """Run the guest side of arbiter-assisted encrypted FTL training.

    Each iteration the guest exchanges encrypted components with the
    host, then relies on the arbiter to decrypt its gradients and loss
    and to decide whether training has converged.

    :param guest_data: raw guest dataset; converted by ``prepare_data``
        into features, labels, and overlap / non-overlap sample indexes.
    """
    LOGGER.info("@ start guest fit")

    # Obtain the Paillier public key published by the arbiter.
    public_key = self._do_get(name=self.transfer_variable.paillier_pubkey.name,
                              tag=self.transfer_variable.generate_transferid(
                                  self.transfer_variable.paillier_pubkey),
                              idx=-1)[0]

    guest_x, overlap_indexes, non_overlap_indexes, guest_y = self.prepare_data(guest_data)
    LOGGER.debug("guest_x: " + str(guest_x.shape))
    LOGGER.debug("guest_y: " + str(guest_y.shape))
    LOGGER.debug("overlap_indexes: " + str(len(overlap_indexes)))
    LOGGER.debug("non_overlap_indexes: " + str(len(non_overlap_indexes)))

    self.guest_model.set_batch(guest_x, guest_y, non_overlap_indexes, overlap_indexes)
    self.guest_model.set_public_key(public_key)

    start_time = time.time()
    while self.n_iter_ < self.max_iter:

        # Send encrypted guest components to the host.
        guest_comp = self.guest_model.send_components()
        LOGGER.debug("send guest_comp: " + create_shape_msg(guest_comp))
        self._do_remote(guest_comp,
                        name=self.transfer_variable.guest_component_list.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.guest_component_list, self.n_iter_),
                        role=consts.HOST,
                        idx=-1)

        # Receive encrypted host components and fold them into the guest model.
        host_comp = self._do_get(name=self.transfer_variable.host_component_list.name,
                                 tag=self.transfer_variable.generate_transferid(
                                     self.transfer_variable.host_component_list, self.n_iter_),
                                 idx=-1)[0]
        LOGGER.debug("receive host_comp: " + create_shape_msg(host_comp))
        self.guest_model.receive_components(host_comp)

        self._precompute()

        # Send encrypted guest gradients to the arbiter for decryption.
        encrypt_guest_gradients = self.guest_model.send_gradients()
        LOGGER.debug("send encrypt_guest_gradients: " + create_shape_msg(encrypt_guest_gradients))
        self._do_remote(encrypt_guest_gradients,
                        name=self.transfer_variable.encrypt_guest_gradient.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.encrypt_guest_gradient, self.n_iter_),
                        role=consts.ARBITER,
                        idx=-1)

        # Receive decrypted gradients back and update the guest model.
        decrypt_guest_gradients = self._do_get(
            name=self.transfer_variable.decrypt_guest_gradient.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.decrypt_guest_gradient, self.n_iter_),
            idx=-1)[0]
        LOGGER.debug("receive decrypt_guest_gradients: " + create_shape_msg(decrypt_guest_gradients))
        self.guest_model.receive_gradients(decrypt_guest_gradients)

        # Send encrypted loss to the arbiter so it can test convergence.
        encrypt_loss = self.guest_model.send_loss()
        self._do_remote(encrypt_loss,
                        name=self.transfer_variable.encrypt_loss.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.encrypt_loss, self.n_iter_),
                        role=consts.ARBITER,
                        idx=-1)

        # Arbiter's stop decision for this iteration.
        is_stop = self._do_get(name=self.transfer_variable.is_encrypted_ftl_stopped.name,
                               tag=self.transfer_variable.generate_transferid(
                                   self.transfer_variable.is_encrypted_ftl_stopped, self.n_iter_),
                               idx=-1)[0]

        # NOTE: added the space after "converged:" for consistency with the
        # host-side fit's log format.
        LOGGER.info("@ time: " + str(time.time()) + ", ep: " + str(self.n_iter_)
                    + ", converged: " + str(is_stop))
        self.n_iter_ += 1
        if is_stop:
            break

    end_time = time.time()
    LOGGER.info("@ running time: " + str(end_time - start_time))