def prepare_data(self, intersect_obj, data_inst, guest_side=False):
    """
    Find intersecting (overlap) sample ids and build the FTL data loader.

    Parameters
    ----------
    intersect_obj : intersection component whose ``run`` returns the overlapping samples
    data_inst : distributed data table of input samples
    guest_side : bool
        When True, labels are validated and non-overlap samples are required.

    Returns
    -------
    tuple
        (data_loader, x_shape, total sample count, overlap sample count)

    Raises
    ------
    ValueError
        if there are no overlap samples, or (guest side) no non-overlap samples.
    """
    if guest_side:
        data_inst = self.check_label(data_inst)

    overlap_samples = intersect_obj.run(data_inst)  # find intersect ids
    non_overlap_samples = data_inst.subtractByKey(overlap_samples)

    # count() on a distributed table is expensive — compute each count once
    overlap_count = overlap_samples.count()
    non_overlap_count = non_overlap_samples.count()
    LOGGER.debug('num of overlap/non-overlap samples: {}/{}'.format(
        overlap_count, non_overlap_count))

    if overlap_count == 0:
        raise ValueError('no overlap samples')

    # BUGFIX: the original compared the table object itself to 0
    # (non_overlap_samples == 0), which is always False, so this guard never
    # fired; compare the sample count instead. The message also referred to
    # the wrong sample set (overlap-count-zero is already handled above).
    if guest_side and non_overlap_count == 0:
        raise ValueError('non-overlap samples are required in guest side')

    self.store_header = data_inst.schema['header']
    LOGGER.debug('data inst header is {}'.format(self.store_header))
    LOGGER.debug('has {} overlap samples'.format(overlap_count))

    batch_size = self.batch_size
    if self.batch_size == -1:
        batch_size = data_inst.count() + 1  # make sure larger than sample number

    data_loader = FTLDataLoader(non_overlap_samples=non_overlap_samples,
                                batch_size=batch_size,
                                overlap_samples=overlap_samples,
                                guest_side=guest_side)

    LOGGER.debug("data details are :{}".format(data_loader.data_basic_info()))

    return data_loader, data_loader.x_shape, data_inst.count(), \
        len(data_loader.get_overlap_indexes())
def batch_compute_components(self, data_loader: FTLDataLoader):
    """
    Compute the guest-side components of the FTL objective.

    Returns Φ_A, its Gram matrix (Φ_A)'(Φ_A), the stacked ua matrix of the
    overlapping samples, and the three components to be sent to the host.
    """
    # Φ_A [1, feature_dim]
    phi, overlap_ua = self.compute_phi_and_overlap_ua(data_loader)

    # (Φ_A)'(Φ_A): [feature_dim, feature_dim]
    phi_product = np.matmul(phi.transpose(), phi)

    # lazily cache overlap labels and their element-wise square
    if self.overlap_y is None:
        self.overlap_y = data_loader.get_overlap_y()       # {C(y)=y}
    if self.overlap_y_2 is None:
        self.overlap_y_2 = self.overlap_y * self.overlap_y  # {D(y)=y^2}

    # stack per-batch ua into one matrix: [overlap_num, feat_dim]
    overlap_ua = np.concatenate(overlap_ua, axis=0)

    # the three components shipped to the host
    host_components = [
        0.25 * np.expand_dims(self.overlap_y_2, axis=2) * phi_product,  # y_overlap_2_phi_2
        -0.5 * self.overlap_y * phi,                                    # y_overlap_phi
        -overlap_ua * self.constant_k,                                  # mapping_comp_a
    ]
    return phi, phi_product, overlap_ua, host_components
def compute_phi_and_overlap_ua(self, data_loader: FTLDataLoader):
    """
    Compute Φ (label-weighted average of bottom-model outputs over all
    samples) together with the ua batches of the overlapping samples.
    """
    phi = None          # running sum of y * ua, averaged at the end: [1, feature_dim]
    overlap_ua = []     # per-batch ua rows belonging to overlap samples

    for batch_idx in range(len(data_loader)):
        batch_x, batch_y = data_loader[batch_idx]
        ua_batch = self.nn.predict(batch_x)  # [batch_size, feature_dim]

        rel_overlap_idx = data_loader.get_relative_overlap_index(batch_idx)
        if len(rel_overlap_idx) != 0:
            if self.verbose:
                LOGGER.debug('batch {}/{} overlap index is {}'.format(
                    batch_idx, len(data_loader), rel_overlap_idx))
            overlap_ua.append(ua_batch[rel_overlap_idx])

        # accumulate sum_i y_i * ua_i for this batch
        batch_phi = np.expand_dims(np.sum(batch_y * ua_batch, axis=0), axis=0)
        phi = batch_phi if phi is None else phi + batch_phi

    phi = phi / self.data_num
    return phi, overlap_ua
def update_nn_weights(self, backward_grads, data_loader: FTLDataLoader, epoch_idx, decay=False):
    """
    Update the bottom nn model weights using the backward gradients.

    Parameters
    ----------
    backward_grads : per-sample gradients, aligned index-for-index with data_loader.x
    data_loader : FTLDataLoader holding the local features
    epoch_idx : int, current epoch index (used for lr decay and logging)
    decay : bool
        When True, apply learning-rate decay before applying gradients.

    Raises
    ------
    ValueError
        if backward_grads is not aligned with the loaded samples.
    """
    LOGGER.debug('updating grads at epoch {}'.format(epoch_idx))

    # BUGFIX: this was a bare `assert`, which is stripped under `python -O`;
    # raise explicitly so the sanity check always runs.
    if len(data_loader.x) != len(backward_grads):
        raise ValueError('number of backward gradients does not match number of samples')

    # accumulate per-batch weight gradients into one gradient per weight tensor
    weight_grads = []
    for batch_idx in range(len(data_loader)):
        start, end = data_loader.get_batch_indexes(batch_idx)
        batch_weight_grads = self._get_mini_batch_gradient(
            data_loader.x[start:end], backward_grads[start:end])
        if not weight_grads:
            weight_grads = list(batch_weight_grads)
        else:
            # rebind instead of in-place `+=` so correctness does not depend
            # on the gradient arrays being mutable
            weight_grads = [w + bw for w, bw in zip(weight_grads, batch_weight_grads)]

    if decay:
        new_learning_rate = self.learning_rate_decay(self.learning_rate, epoch_idx)
        self.nn.set_learning_rate(new_learning_rate)
        LOGGER.debug('epoch {} optimizer details are {}'.format(
            epoch_idx, self.nn.export_optimizer_config()))

    self.nn.apply_gradients(weight_grads)
def compute_backward_gradients(self, host_components, data_loader: FTLDataLoader, epoch_idx, local_round=-1): """ compute backward gradients using host components """ # they are Paillier tensors or np array overlap_ub, overlap_ub_2, mapping_comp_b = host_components[0], host_components[1], host_components[2] y_overlap_2_phi = np.expand_dims(self.overlap_y_2 * self.phi, axis=1) if self.mode == 'plain': loss_grads_const_part1 = 0.25 * np.squeeze(np.matmul(y_overlap_2_phi, overlap_ub_2), axis=1) loss_grads_const_part2 = self.overlap_y * overlap_ub const = np.sum(loss_grads_const_part1, axis=0) - 0.5 * np.sum(loss_grads_const_part2, axis=0) grad_a_nonoverlap = self.alpha * const * data_loader.y[data_loader.get_non_overlap_indexes()] / self.data_num grad_a_overlap = self.alpha * const * self.overlap_y / self.data_num + mapping_comp_b return np.concatenate([grad_a_overlap, grad_a_nonoverlap], axis=0) elif self.mode == 'encrypted': loss_grads_const_part1 = overlap_ub_2.matmul_3d(0.25 * y_overlap_2_phi, multiply='right') loss_grads_const_part1 = loss_grads_const_part1.squeeze(axis=1) if self.overlap_y_pt is None: self.overlap_y_pt = PaillierTensor(self.overlap_y, partitions=self.partitions) loss_grads_const_part2 = overlap_ub * self.overlap_y_pt encrypted_const = loss_grads_const_part1.reduce_sum() - 0.5 * loss_grads_const_part2.reduce_sum() grad_a_overlap = self.overlap_y_pt.map_ndarray_product((self.alpha/self.data_num * encrypted_const)) + mapping_comp_b const, grad_a_overlap = self.decrypt_inter_result(encrypted_const, grad_a_overlap, epoch_idx=epoch_idx , local_round=local_round) self.decrypt_host_data(epoch_idx, local_round=local_round) grad_a_nonoverlap = self.alpha * const * data_loader.y[data_loader.get_non_overlap_indexes()]/self.data_num return np.concatenate([grad_a_overlap.numpy(), grad_a_nonoverlap], axis=0)