class Guest(batch_info_sync.Guest):
    def __init__(self):
        self.mini_batch_obj = None
        self.finish_sycn = False
        self.batch_nums = None

    def register_batch_generator(self, transfer_variables, has_arbiter=True):
        self._register_batch_data_index_transfer(transfer_variables.batch_info,
                                                 transfer_variables.batch_data_index,
                                                 has_arbiter)

    def initialize_batch_generator(self, data_instances, batch_size, suffix=tuple()):
        self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size)
        self.batch_nums = self.mini_batch_obj.batch_nums
        batch_info = {"batch_size": batch_size, "batch_num": self.batch_nums}
        self.sync_batch_info(batch_info, suffix)

        index_generator = self.mini_batch_obj.mini_batch_data_generator(result='index')
        batch_index = 0
        for batch_data_index in index_generator:
            batch_suffix = suffix + (batch_index,)
            self.sync_batch_index(batch_data_index, batch_suffix)
            batch_index += 1

    def generate_batch_data(self):
        data_generator = self.mini_batch_obj.mini_batch_data_generator(result='data')
        for batch_data in data_generator:
            yield batch_data
def test_mini_batch_data_generator(self, data_num=100, batch_size=320):
    t0 = time.time()
    feature_num = 20
    expect_batches = data_num // batch_size
    # print("expect_batches: {}".format(expect_batches))
    data_instances = self.prepare_data(data_num=data_num, feature_num=feature_num)
    # print("Prepare data time: {}".format(time.time() - t0))
    mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=batch_size)
    batch_data_generator = mini_batch_obj.mini_batch_data_generator()
    batch_id = 0
    pre_time = time.time() - t0
    # print("Prepare mini batch time: {}".format(pre_time))
    total_num = 0
    for batch_data in batch_data_generator:
        batch_num = batch_data.count()
        if batch_id < expect_batches - 1:
            # print("In mini batch test, batch_num: {}, batch_size:{}".format(
            #     batch_num, batch_size
            # ))
            self.assertEqual(batch_num, batch_size)
        batch_id += 1
        total_num += batch_num
        # curt_time = time.time()
        # print("One batch time: {}".format(curt_time - pre_time))
        # pre_time = curt_time
    self.assertEqual(total_num, data_num)
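# A minimal, self-contained sketch (not the FATE implementation) of the contract the test
# above checks: every batch except possibly the last holds `batch_size` rows, and the batch
# sizes sum back to `data_num`. Plain Python lists stand in for DTables; the "-1 means one
# full batch" convention is an assumption carried over from the surrounding snippets.
def simple_mini_batches(ids, batch_size):
    """Yield consecutive slices of `ids`, each of length `batch_size` (last may be shorter)."""
    if batch_size <= 0:
        yield ids
        return
    for start in range(0, len(ids), batch_size):
        yield ids[start:start + batch_size]

ids = list(range(100))
batches = list(simple_mini_batches(ids, batch_size=320))
assert sum(len(b) for b in batches) == len(ids)
assert all(len(b) == 320 for b in batches[:-1])   # vacuously true here: a single batch of 100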
def initialize_batch_generator(self, data_instances, batch_size, suffix=tuple(), shuffle=False,
                               batch_strategy="full", masked_rate=0):
    self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size, shuffle=shuffle,
                                    batch_strategy=batch_strategy, masked_rate=masked_rate)
    self.batch_nums = self.mini_batch_obj.batch_nums
    self.batch_masked = self.mini_batch_obj.batch_size != self.mini_batch_obj.masked_batch_size
    batch_info = {"batch_size": self.mini_batch_obj.batch_size,
                  "batch_num": self.batch_nums,
                  "batch_mutable": self.mini_batch_obj.batch_mutable,
                  "masked_batch_size": self.mini_batch_obj.masked_batch_size}
    self.sync_batch_info(batch_info, suffix)

    if not self.mini_batch_obj.batch_mutable:
        self.prepare_batch_data(suffix)
def initialize_batch_generator(self, data_instances, batch_size, suffix=tuple()):
    self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size)
    self.batch_nums = self.mini_batch_obj.batch_nums
    batch_info = {"batch_size": batch_size, "batch_num": self.batch_nums}
    self.sync_batch_info(batch_info, suffix)

    index_generator = self.mini_batch_obj.mini_batch_data_generator(result='index')
    batch_index = 0
    for batch_data_index in index_generator:
        batch_suffix = suffix + (batch_index,)
        self.sync_batch_index(batch_data_index, batch_suffix)
        batch_index += 1
def __init_parameters(self, data_instances):
    party_weight_id = self.transfer_variable.generate_transferid(
        self.transfer_variable.host_party_weight
    )
    # LOGGER.debug("party_weight_id: {}".format(party_weight_id))
    federation.remote(self.party_weight,
                      name=self.transfer_variable.host_party_weight.name,
                      tag=party_weight_id,
                      role=consts.ARBITER,
                      idx=0)

    self.__synchronize_encryption()

    # Send re-encrypt times
    self.mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)
    if self.use_encrypt:
        # LOGGER.debug("Use encryption, send re_encrypt_times")
        total_batch_num = self.mini_batch_obj.batch_nums
        re_encrypt_times = total_batch_num // self.re_encrypt_batches
        transfer_id = self.transfer_variable.generate_transferid(self.transfer_variable.re_encrypt_times)
        federation.remote(re_encrypt_times,
                          name=self.transfer_variable.re_encrypt_times.name,
                          tag=transfer_id,
                          role=consts.ARBITER,
                          idx=0)
        LOGGER.info("sent re_encrypt_times: {}".format(re_encrypt_times))
def fit(self, data_instances, validate_data=None): self._abnormal_detection(data_instances) self.init_schema(data_instances) validation_strategy = self.init_validation_strategy( data_instances, validate_data) self.model_weights = self._init_model_variables(data_instances) max_iter = self.max_iter total_data_num = data_instances.count() mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) model_weights = self.model_weights self.__synchronize_encryption() self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get( idx=0, suffix=('train', )) LOGGER.debug("party num:" + str(self.zcl_num_party)) self.__init_model() self.train_loss_results = [] self.train_accuracy_results = [] self.test_loss_results = [] self.test_accuracy_results = [] for iter_num in range(self.max_iter): total_loss = 0 batch_num = 0 epoch_train_loss_avg = tfe.metrics.Mean() epoch_train_accuracy = tfe.metrics.Accuracy() for train_x, train_y in self.zcl_dataset: LOGGER.info("Staring batch {}".format(batch_num)) start_t = time.time() loss_value, grads = self.__grad(self.zcl_model, train_x, train_y) loss_value = loss_value.numpy() grads = [x.numpy() for x in grads] LOGGER.info("Start encrypting") loss_value = batch_encryption.encrypt( self.zcl_encrypt_operator.get_public_key(), loss_value) grads = [ batch_encryption.encrypt_matrix( self.zcl_encrypt_operator.get_public_key(), x) for x in grads ] grads = Gradients(grads) LOGGER.info("Finish encrypting") # grads = self.encrypt_operator.get_public_key() self.transfer_variable.guest_grad.remote( obj=grads.for_remote(), role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num)) LOGGER.info("Sent grads") self.transfer_variable.guest_loss.remote(obj=loss_value, role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num)) LOGGER.info("Sent loss") sum_grads = self.transfer_variable.aggregated_grad.get( idx=0, suffix=(iter_num, batch_num)) LOGGER.info("Got grads") sum_loss = self.transfer_variable.aggregated_loss.get( idx=0, suffix=(iter_num, batch_num)) LOGGER.info("Got loss") sum_loss = batch_encryption.decrypt( self.zcl_encrypt_operator.get_privacy_key(), sum_loss) sum_grads = [ batch_encryption.decrypt_matrix( self.zcl_encrypt_operator.get_privacy_key(), x).astype(np.float32) for x in sum_grads.unboxed ] LOGGER.info("Finish decrypting") # sum_grads = np.array(sum_grads) / self.zcl_num_party self.zcl_optimizer.apply_gradients( zip(sum_grads, self.zcl_model.trainable_variables), self.zcl_global_step) elapsed_time = time.time() - start_t # epoch_train_loss_avg(loss_value) # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32), # train_y) self.train_loss_results.append(sum_loss) train_accuracy_v = accuracy_score( train_y, tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32)) self.train_accuracy_results.append(train_accuracy_v) test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test, self.zcl_y_test) self.test_loss_results.append(test_loss_v) test_accuracy_v = accuracy_score( self.zcl_y_test, tf.argmax(self.zcl_model(self.zcl_x_test), axis=1, output_type=tf.int32)) self.test_accuracy_results.append(test_accuracy_v) LOGGER.info( "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, " "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format( iter_num, batch_num, sum_loss, train_accuracy_v, test_loss_v, test_accuracy_v, elapsed_time)) batch_num += 1 if batch_num >= self.zcl_early_stop_batch: return self.n_iter_ = iter_num
def fit(self, data_instances, validate_data=None): self._abnormal_detection(data_instances) self.check_abnormal_values(data_instances) self.init_schema(data_instances) validation_strategy = self.init_validation_strategy( data_instances, validate_data) self.model_weights = self._init_model_variables(data_instances) max_iter = self.max_iter # total_data_num = data_instances.count() mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) model_weights = self.model_weights degree = 0 self.prev_round_weights = copy.deepcopy(model_weights) while self.n_iter_ < max_iter + 1: batch_data_generator = mini_batch_obj.mini_batch_data_generator() self.optimizer.set_iters(self.n_iter_) if ((self.n_iter_ + 1) % self.aggregate_iters == 0) or self.n_iter_ == max_iter: weight = self.aggregator.aggregate_then_get( model_weights, degree=degree, suffix=self.n_iter_) self.model_weights = LogisticRegressionWeights( weight.unboxed, self.fit_intercept) # store prev_round_weights after aggregation self.prev_round_weights = copy.deepcopy(self.model_weights) # send loss to arbiter loss = self._compute_loss(data_instances, self.prev_round_weights) self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_, )) degree = 0 self.is_converged = self.aggregator.get_converge_status( suffix=(self.n_iter_, )) LOGGER.info( "n_iters: {}, loss: {} converge flag is :{}".format( self.n_iter_, loss, self.is_converged)) if self.is_converged or self.n_iter_ == max_iter: break model_weights = self.model_weights batch_num = 0 for batch_data in batch_data_generator: n = batch_data.count() # LOGGER.debug("In each batch, lr_weight: {}, batch_data count: {}".format(model_weights.unboxed, n)) f = functools.partial(self.gradient_operator.compute_gradient, coef=model_weights.coef_, intercept=model_weights.intercept_, fit_intercept=self.fit_intercept) grad = batch_data.applyPartitions(f).reduce( fate_operator.reduce_add) grad /= n # LOGGER.debug('iter: {}, batch_index: {}, grad: {}, n: {}'.format( # self.n_iter_, batch_num, grad, n)) if self.use_proximal: # use proximal term model_weights = self.optimizer.update_model( model_weights, grad=grad, has_applied=False, prev_round_weights=self.prev_round_weights) else: model_weights = self.optimizer.update_model( model_weights, grad=grad, has_applied=False) batch_num += 1 degree += n validation_strategy.validate(self, self.n_iter_) self.n_iter_ += 1 self.set_summary(self.get_model_summary())
def fit(self, data_instances, node2id, local_instances=None, common_nodes=None): """ Train node embedding for role guest Parameters ---------- data_instances: DTable of target node and label, input data node2id: a dict which can map node name to id """ LOGGER.info("samples number:{}".format(data_instances.count())) LOGGER.info("Enter network embedding procedure:") self.n_node = len(node2id) LOGGER.info("Bank A has {} nodes".format(self.n_node)) data_instances = data_instances.mapValues(HeteroNEGuest.load_data) LOGGER.info("Transform input data to train instance") public_key = federation.get( name=self.transfer_variable.paillier_pubkey.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.paillier_pubkey), idx=0) LOGGER.info("Get public_key from arbiter:{}".format(public_key)) self.encrypt_operator.set_public_key(public_key) # hetero network embedding LOGGER.info("Generate mini-batch from input data") mini_batch_obj = MiniBatch(data_instances, batch_size=self.batch_size) batch_num = mini_batch_obj.batch_nums LOGGER.info("samples number:{}".format(data_instances.count())) if self.batch_size == -1: LOGGER.info( "batch size is -1, set it to the number of data in data_instances" ) self.batch_size = data_instances.count() ############## # horizontal federated learning LOGGER.info("Generate mini-batch for local instances in guest") mini_batch_obj_local = MiniBatch(local_instances, batch_size=self.batch_size) local_batch_num = mini_batch_obj_local.batch_nums common_node_instances = eggroll.parallelize( ((node, node) for node in common_nodes), include_key=True, name='common_nodes') ############## batch_info = {'batch_size': self.batch_size, "batch_num": batch_num} LOGGER.info("batch_info:{}".format(batch_info)) federation.remote(batch_info, name=self.transfer_variable.batch_info.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_info), role=consts.HOST, idx=0) LOGGER.info("Remote batch_info to Host") federation.remote(batch_info, name=self.transfer_variable.batch_info.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_info), role=consts.ARBITER, idx=0) LOGGER.info("Remote batch_info to Arbiter") self.encrypted_calculator = [ EncryptModeCalculator( self.encrypt_operator, self.encrypted_mode_calculator_param.mode, self.encrypted_mode_calculator_param.re_encrypted_rate) for _ in range(batch_num) ] LOGGER.info("Start initialize model.") self.embedding_ = self.initializer.init_model((self.n_node, self.dim), self.init_param_obj) LOGGER.info("Embedding shape={}".format(self.embedding_.shape)) is_send_all_batch_index = False self.n_iter_ = 0 index_data_inst_map = {} while self.n_iter_ < self.max_iter: LOGGER.info("iter:{}".format(self.n_iter_)) ################# local_batch_data_generator = mini_batch_obj_local.mini_batch_data_generator( ) total_loss = 0 local_batch_num = 0 LOGGER.info("Enter the horizontally federated learning procedure:") for local_batch_data in local_batch_data_generator: n = local_batch_data.count() #LOGGER.info("Local batch data count:{}".format(n)) E_Y = self.compute_local_embedding(local_batch_data, self.embedding_, node2id) local_grads_e1, local_grads_e2, local_loss = self.local_gradient_operator.compute( E_Y, 'E_1') local_grads_e1 = local_grads_e1.mapValues( lambda g: self.local_optimizer.apply_gradients(g / n)) local_grads_e2 = local_grads_e2.mapValues( lambda g: self.local_optimizer.apply_gradients(g / n)) e1id_join_grads = local_batch_data.join( local_grads_e1, lambda v, g: (node2id[v[0]], 
g)) e2id_join_grads = local_batch_data.join( local_grads_e2, lambda v, g: (node2id[v[1]], g)) self.update_model(e1id_join_grads) self.update_model(e2id_join_grads) local_loss = local_loss / n local_batch_num += 1 total_loss += local_loss #LOGGER.info("gradient count:{}".format(e1id_join_grads.count())) guest_common_embedding = common_node_instances.mapValues( lambda node: self.embedding_[node2id[node]]) federation.remote( guest_common_embedding, name=self.transfer_variable.guest_common_embedding.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.guest_common_embedding, self.n_iter_, 0), role=consts.ARBITER, idx=0) LOGGER.info("Remote the embedding of common node to arbiter!") common_embedding = federation.get( name=self.transfer_variable.common_embedding.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.common_embedding, self.n_iter_, 0), idx=0) LOGGER.info( "Get the aggregated embedding of common node from arbiter!") self.update_common_nodes(common_embedding, common_nodes, node2id) total_loss /= local_batch_num LOGGER.info( "Iter {}, horizontally feaderated learning loss: {}".format( self.n_iter_, total_loss)) ################# # verticallly feaderated learning # each iter will get the same batch_data_generator LOGGER.info("Enter the vertically federated learning:") batch_data_generator = mini_batch_obj.mini_batch_data_generator( result='index') batch_index = 0 for batch_data_index in batch_data_generator: LOGGER.info("batch:{}".format(batch_index)) # only need to send one times if not is_send_all_batch_index: LOGGER.info("remote mini-batch index to Host") federation.remote( batch_data_index, name=self.transfer_variable.batch_data_index.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_data_index, self.n_iter_, batch_index), role=consts.HOST, idx=0) if batch_index >= mini_batch_obj.batch_nums - 1: is_send_all_batch_index = True # in order to avoid joining in next iteration # Get mini-batch train data if len(index_data_inst_map) < batch_num: batch_data_inst = data_instances.join( batch_data_index, lambda data_inst, index: data_inst) index_data_inst_map[batch_index] = batch_data_inst else: batch_data_inst = index_data_inst_map[batch_index] # For inductive learning: transform node attributes to node embedding # self.transform(batch_data_inst) self.guest_forward = self.compute_forward( batch_data_inst, self.embedding_, node2id, batch_index) host_forward = federation.get( name=self.transfer_variable.host_forward_dict.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.host_forward_dict, self.n_iter_, batch_index), idx=0) LOGGER.info("Get host_forward from host") aggregate_forward_res = self.aggregate_forward(host_forward) en_aggregate_ee = aggregate_forward_res.mapValues( lambda v: v[0]) en_aggregate_ee_square = aggregate_forward_res.mapValues( lambda v: v[1]) # compute [[d]] if self.gradient_operator is None: self.gradient_operator = HeteroNetworkEmbeddingGradient( self.encrypt_operator) fore_gradient = self.gradient_operator.compute_fore_gradient( batch_data_inst, en_aggregate_ee) host_gradient = self.gradient_operator.compute_gradient( self.guest_forward.mapValues( lambda v: Instance(features=v[1])), fore_gradient) federation.remote( host_gradient, name=self.transfer_variable.host_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.host_gradient, self.n_iter_, batch_index), role=consts.ARBITER, idx=0) LOGGER.info("Remote host_gradient to arbiter") 
composed_data_inst = host_forward.join( batch_data_inst, lambda hf, d: Instance(features=hf[1], label=d.label)) guest_gradient, loss = self.gradient_operator.compute_gradient_and_loss( composed_data_inst, fore_gradient, en_aggregate_ee, en_aggregate_ee_square) federation.remote( guest_gradient, name=self.transfer_variable.guest_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.guest_gradient, self.n_iter_, batch_index), role=consts.ARBITER, idx=0) LOGGER.info("Remote guest_gradient to arbiter") optim_guest_gradient = federation.get( name=self.transfer_variable.guest_optim_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.guest_optim_gradient, self.n_iter_, batch_index), idx=0) LOGGER.info("Get optim_guest_gradient from arbiter") # update node embedding LOGGER.info("Update node embedding") nodeid_join_gradient = batch_data_inst.join( optim_guest_gradient, lambda instance, gradient: (node2id[instance.features], gradient)) self.update_model(nodeid_join_gradient) # update local model that transform attribute to node embedding training_info = { 'iteration': self.n_iter_, 'batch_index': batch_index } self.update_local_model(fore_gradient, batch_data_inst, self.embedding_, **training_info) # loss need to be encrypted !!!!!! federation.remote( loss, name=self.transfer_variable.loss.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.loss, self.n_iter_, batch_index), role=consts.ARBITER, idx=0) LOGGER.info("Remote loss to arbiter") # is converge of loss in arbiter batch_index += 1 # remove temporary resource rubbish_list = [ host_forward, aggregate_forward_res, en_aggregate_ee, en_aggregate_ee_square, fore_gradient, self.guest_forward ] rubbish_clear(rubbish_list) ########## guest_common_embedding = common_node_instances.mapValues( lambda node: self.embedding_[node2id[node]]) federation.remote( guest_common_embedding, name=self.transfer_variable.guest_common_embedding.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.guest_common_embedding, self.n_iter_, 1), role=consts.ARBITER, idx=0) common_embedding = federation.get( name=self.transfer_variable.common_embedding.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.common_embedding, self.n_iter_, 1), idx=0) self.update_common_nodes(common_embedding, common_nodes, node2id) ########## is_stopped = federation.get( name=self.transfer_variable.is_stopped.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.is_stopped, self.n_iter_), idx=0) LOGGER.info("Get is_stop flag from arbiter:{}".format(is_stopped)) self.n_iter_ += 1 if is_stopped: LOGGER.info( "Get stop signal from arbiter, model is converged, iter:{}" .format(self.n_iter_)) break embedding_table = eggroll.table(name='guest', namespace='node_embedding', partition=10) id2node = dict(zip(node2id.values(), node2id.keys())) for id, embedding in enumerate(self.embedding_): embedding_table.put(id2node[id], embedding) embedding_table.save_as(name='guest', namespace='node_embedding', partition=10) LOGGER.info("Reach max iter {}, train model finish!".format( self.max_iter))
def fit(self, data_instances, validate_data=None): LOGGER.debug("Start data count: {}".format(data_instances.count())) self._abnormal_detection(data_instances) self.init_schema(data_instances) validation_strategy = self.init_validation_strategy( data_instances, validate_data) self.model_weights = self._init_model_variables(data_instances) LOGGER.debug("After init, model_weights: {}".format( self.model_weights.unboxed)) mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) model_weights = self.model_weights degree = 0 while self.n_iter_ < self.max_iter + 1: LOGGER.info("iter:{}".format(self.n_iter_)) batch_data_generator = mini_batch_obj.mini_batch_data_generator() if (self.n_iter_ > 0 and self.n_iter_ % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter: weight = self.aggregator.aggregate_then_get( weight, degree=degree, suffix=self.n_iter_) weight._weights = np.array(weight._weights) # This weight is transferable Weight, should get the parameter back self.model_weights.update(weight) LOGGER.debug( "Before aggregate: {}, degree: {} after aggregated: {}". format(model_weights.unboxed / degree, degree, self.model_weights.unboxed)) loss = self._compute_loss(data_instances) self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_, )) degree = 0 self.is_converged = self.aggregator.get_converge_status( suffix=(self.n_iter_, )) LOGGER.info("n_iters: {}, is_converge: {}".format( self.n_iter_, self.is_converged)) if self.is_converged: break model_weights = self.model_weights batch_num = 0 for batch_data in batch_data_generator: n = batch_data.count() LOGGER.debug("before compute_gradient,w_:{},embed_:{}".format( model_weights.w_, model_weights.embed_)) f = functools.partial(self.gradient_operator.compute_gradient, w=model_weights.w_, embed=model_weights.embed_, intercept=model_weights.intercept_, fit_intercept=self.fit_intercept) grad = batch_data.mapPartitions(f).reduce( fate_operator.reduce_add) grad /= n weight = self.optimizer.update_model(model_weights, grad, has_applied=False) weight._weights = np.array(weight._weights) model_weights.update(weight) batch_num += 1 degree += n validation_strategy.validate(self, self.n_iter_) self.n_iter_ += 1 LOGGER.info("Finish Training task, total iters: {}".format( self.n_iter_))
def fit(self, data_instances, validate_data=None): self._abnormal_detection(data_instances) self.init_schema(data_instances) validation_strategy = self.init_validation_strategy( data_instances, validate_data) self.model_weights = self._init_model_variables(data_instances) mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) model_weights = self.model_weights degree = 0 while self.n_iter_ < self.max_iter + 1: LOGGER.info("iter:{}".format(self.n_iter_)) batch_data_generator = mini_batch_obj.mini_batch_data_generator() self.optimizer.set_iters(self.n_iter_) if (self.n_iter_ > 0 and self.n_iter_ % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter: # This loop will run after weight has been created,weight will be in LRweights weight = self.aggregator.aggregate_then_get( weight, degree=degree, suffix=self.n_iter_) weight._weights = np.array(weight._weights) # This weight is transferable Weight, should get the parameter back self.model_weights.update(weight) LOGGER.debug( "Before aggregate: {}, degree: {} after aggregated: {}". format(model_weights.unboxed / degree, degree, self.model_weights.unboxed)) loss = self._compute_loss(data_instances) self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_, )) degree = 0 self.is_converged = self.aggregator.get_converge_status( suffix=(self.n_iter_, )) LOGGER.info( "n_iters: {}, loss: {} converge flag is :{}".format( self.n_iter_, loss, self.is_converged)) if self.is_converged: break model_weights = self.model_weights batch_num = 0 for batch_data in batch_data_generator: n = batch_data.count() LOGGER.debug( "In each batch, fm_weight: {}, batch_data count: {},w:{},embed:{}" .format(model_weights.unboxed, n, model_weights.w_, model_weights.embed_)) f = functools.partial(self.gradient_operator.compute_gradient, w=model_weights.w_, embed=model_weights.embed_, intercept=model_weights.intercept_, fit_intercept=self.fit_intercept) grad = batch_data.mapPartitions(f).reduce( fate_operator.reduce_add) grad /= n LOGGER.debug( 'iter: {}, batch_index: {}, grad: {}, n: {}'.format( self.n_iter_, batch_num, grad, n)) weight = self.optimizer.update_model(model_weights, grad, has_applied=False) weight._weights = np.array(weight._weights) model_weights.update(weight) batch_num += 1 degree += n validation_strategy.validate(self, self.n_iter_) self.n_iter_ += 1
def fit(self, data_instances): LOGGER.info("Enter hetero_lr_guest fit") data_instances = data_instances.mapValues(HeteroLRGuest.load_data) public_key = federation.get( name=self.transfer_variable.paillier_pubkey.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.paillier_pubkey), idx=0) LOGGER.info("Get public_key from arbiter:{}".format(public_key)) self.encrypt_operator.set_public_key(public_key) LOGGER.info("Generate mini-batch from input data") mini_batch_obj = MiniBatch(data_instances, batch_size=self.batch_size) batch_info = { "batch_size": self.batch_size, "batch_num": mini_batch_obj.batch_nums } LOGGER.info("batch_info:" + str(batch_info)) federation.remote(batch_info, name=self.transfer_variable.batch_info.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_info), role=consts.HOST, idx=0) LOGGER.info("Remote batch_info to Host") federation.remote(batch_info, name=self.transfer_variable.batch_info.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_info), role=consts.ARBITER, idx=0) LOGGER.info("Remote batch_info to Arbiter") LOGGER.info("Start initialize model.") LOGGER.info("fit_intercept:{}".format( self.init_param_obj.fit_intercept)) model_shape = self.get_features_shape(data_instances) weight = self.initializer.init_model(model_shape, init_params=self.init_param_obj) if self.init_param_obj.fit_intercept is True: self.coef_ = weight[:-1] self.intercept_ = weight[-1] else: self.coef_ = weight is_stopped = False is_send_all_batch_index = False self.n_iter_ = 0 while self.n_iter_ < self.max_iter: LOGGER.info("iter:{}".format(self.n_iter_)) batch_data_generator = mini_batch_obj.mini_batch_index_generator( data_inst=data_instances, batch_size=self.batch_size) batch_index = 0 for batch_data_index in batch_data_generator: LOGGER.info("batch:{}".format(batch_index)) if not is_send_all_batch_index: LOGGER.info("remote mini-batch index to Host") federation.remote( batch_data_index, name=self.transfer_variable.batch_data_index.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_data_index, self.n_iter_, batch_index), role=consts.HOST, idx=0) if batch_index >= mini_batch_obj.batch_nums - 1: is_send_all_batch_index = True # Get mini-batch train data batch_data_inst = data_instances.join( batch_data_index, lambda data_inst, index: data_inst) # guest/host forward self.compute_forward(batch_data_inst, self.coef_, self.intercept_) host_forward = federation.get( name=self.transfer_variable.host_forward_dict.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.host_forward_dict, self.n_iter_, batch_index), idx=0) LOGGER.info("Get host_forward from host") aggregate_forward_res = self.aggregate_forward(host_forward) en_aggregate_wx = aggregate_forward_res.mapValues( lambda v: v[0]) en_aggregate_wx_square = aggregate_forward_res.mapValues( lambda v: v[1]) # compute [[d]] if self.gradient_operator is None: self.gradient_operator = HeteroLogisticGradient( self.encrypt_operator) fore_gradient = self.gradient_operator.compute_fore_gradient( batch_data_inst, en_aggregate_wx) federation.remote( fore_gradient, name=self.transfer_variable.fore_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.fore_gradient, self.n_iter_, batch_index), role=consts.HOST, idx=0) LOGGER.info("Remote fore_gradient to Host") # compute guest gradient and loss guest_gradient, loss = self.gradient_operator.compute_gradient_and_loss( batch_data_inst, 
fore_gradient, en_aggregate_wx, en_aggregate_wx_square, self.fit_intercept) # loss regulation if necessary if self.updater is not None: guest_loss_regular = self.updater.loss_norm(self.coef_) loss += self.encrypt_operator.encrypt(guest_loss_regular) federation.remote( guest_gradient, name=self.transfer_variable.guest_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.guest_gradient, self.n_iter_, batch_index), role=consts.ARBITER, idx=0) LOGGER.info("Remote guest_gradient to arbiter") optim_guest_gradient = federation.get( name=self.transfer_variable.guest_optim_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.guest_optim_gradient, self.n_iter_, batch_index), idx=0) LOGGER.info("Get optim_guest_gradient from arbiter") # update model LOGGER.info("update_model") self.update_model(optim_guest_gradient) # Get loss regulation from Host if regulation is set if self.updater is not None: en_host_loss_regular = federation.get( name=self.transfer_variable.host_loss_regular.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.host_loss_regular, self.n_iter_, batch_index), idx=0) LOGGER.info("Get host_loss_regular from Host") loss += en_host_loss_regular federation.remote( loss, name=self.transfer_variable.loss.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.loss, self.n_iter_, batch_index), role=consts.ARBITER, idx=0) LOGGER.info("Remote loss to arbiter") # is converge of loss in arbiter is_stopped = federation.get( name=self.transfer_variable.is_stopped.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.is_stopped, self.n_iter_, batch_index), idx=0) LOGGER.info( "Get is_stop flag from arbiter:{}".format(is_stopped)) batch_index += 1 if is_stopped: LOGGER.info( "Get stop signal from arbiter, model is converged, iter:{}" .format(self.n_iter_)) break self.n_iter_ += 1 if is_stopped: break LOGGER.info("Reach max iter {}, train model finish!".format( self.max_iter))
def fit_binary(self, data_instances, validate_data=None):
    self.aggregator = aggregator.Host()
    self.aggregator.register_aggregator(self.transfer_variable)
    self._client_check_data(data_instances)
    self.callback_list.on_train_begin(data_instances, validate_data)

    pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
    if self.use_encrypt:
        self.cipher_operator.set_public_key(pubkey)

    if not self.component_properties.is_warm_start:
        self.model_weights = self._init_model_variables(data_instances)
        if self.use_encrypt:
            w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
        else:
            w = list(self.model_weights.unboxed)
        self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)
    else:
        self.callback_warm_start_init_iter(self.n_iter_)

    # LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

    mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)
    total_batch_num = mini_batch_obj.batch_nums
    if self.use_encrypt:
        re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1
        # LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
        #     re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
        self.cipher.set_re_cipher_time(re_encrypt_times)

    # total_data_num = data_instances.count()
    # LOGGER.debug("Current data count: {}".format(total_data_num))

    model_weights = self.model_weights
    self.prev_round_weights = copy.deepcopy(model_weights)
    degree = 0
    while self.n_iter_ < self.max_iter + 1:
        self.callback_list.on_epoch_begin(self.n_iter_)
        batch_data_generator = mini_batch_obj.mini_batch_data_generator()
        self.optimizer.set_iters(self.n_iter_)

        if ((self.n_iter_ + 1) % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter:
            weight = self.aggregator.aggregate_then_get(model_weights,
                                                        degree=degree,
                                                        suffix=self.n_iter_)
            # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format(
            #     model_weights.unboxed / degree,
            #     degree,
            #     weight.unboxed))
            self.model_weights = LogisticRegressionWeights(weight.unboxed, self.fit_intercept)
            if not self.use_encrypt:
                loss = self._compute_loss(data_instances, self.prev_round_weights)
                self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_,))
                LOGGER.info("n_iters: {}, loss: {}".format(self.n_iter_, loss))
            degree = 0
            self.is_converged = self.aggregator.get_converge_status(suffix=(self.n_iter_,))
            LOGGER.info("n_iters: {}, is_converge: {}".format(self.n_iter_, self.is_converged))
            if self.is_converged or self.n_iter_ == self.max_iter:
                break
            model_weights = self.model_weights

        batch_num = 0
        for batch_data in batch_data_generator:
            n = batch_data.count()
            degree += n
            LOGGER.debug('before compute_gradient')
            f = functools.partial(self.gradient_operator.compute_gradient,
                                  coef=model_weights.coef_,
                                  intercept=model_weights.intercept_,
                                  fit_intercept=self.fit_intercept)
            grad = batch_data.applyPartitions(f).reduce(fate_operator.reduce_add)
            grad /= n

            if self.use_proximal:  # use additional proximal term
                model_weights = self.optimizer.update_model(model_weights, grad=grad,
                                                            has_applied=False,
                                                            prev_round_weights=self.prev_round_weights)
            else:
                model_weights = self.optimizer.update_model(model_weights, grad=grad,
                                                            has_applied=False)

            if self.use_encrypt and batch_num % self.re_encrypt_batches == 0:
                LOGGER.debug("Before accept re_encrypted_model, batch_iter_num: {}".format(batch_num))
                w = self.cipher.re_cipher(w=model_weights.unboxed,
                                          iter_num=self.n_iter_,
                                          batch_iter_num=batch_num)
                model_weights = LogisticRegressionWeights(w, self.fit_intercept)
            batch_num += 1

        # validation_strategy.validate(self, self.n_iter_)
        self.callback_list.on_epoch_end(self.n_iter_)
        self.n_iter_ += 1
        if self.stop_training:
            break

    self.set_summary(self.get_model_summary())
    LOGGER.info("Finish Training task, total iters: {}".format(self.n_iter_))
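# Small, self-contained illustration (not from the source) of the two re-encryption counts
# seen in these snippets: the plain floor division used in the older __init_parameters above
# versus the ceiling-style form used in fit_binary. With 5 batches and re_encrypt_batches=8,
# floor division schedules zero re-encryptions while the ceiling form still schedules one.
total_batch_num, re_encrypt_batches = 5, 8
floor_times = total_batch_num // re_encrypt_batches               # 0
ceil_times = (total_batch_num - 1) // re_encrypt_batches + 1      # 1
assert (floor_times, ceil_times) == (0, 1)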
def fit(self, data_instances): """ Train lr model of role guest Parameters ---------- data_instances: DTable of Instance, input data """ LOGGER.info("Enter hetero_lr_guest fit") self._abnormal_detection(data_instances) self.header = self.get_header(data_instances) data_instances = data_instances.mapValues(HeteroLRGuest.load_data) # 获得密钥 public_key = federation.get( name=self.transfer_variable.paillier_pubkey.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.paillier_pubkey), idx=0) LOGGER.info("Get public_key from arbiter:{}".format(public_key)) self.encrypt_operator.set_public_key(public_key) LOGGER.info("Generate mini-batch from input data") mini_batch_obj = MiniBatch(data_instances, batch_size=self.batch_size) batch_num = mini_batch_obj.batch_nums if self.batch_size == -1: LOGGER.info( "batch size is -1, set it to the number of data in data_instances" ) self.batch_size = data_instances.count() batch_info = {"batch_size": self.batch_size, "batch_num": batch_num} LOGGER.info("batch_info:{}".format(batch_info)) federation.remote(batch_info, name=self.transfer_variable.batch_info.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_info), role=consts.HOST, idx=0) LOGGER.info("Remote batch_info to Host") federation.remote(batch_info, name=self.transfer_variable.batch_info.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_info), role=consts.ARBITER, idx=0) LOGGER.info("Remote batch_info to Arbiter") self.encrypted_calculator = [ EncryptModeCalculator( self.encrypt_operator, self.encrypted_mode_calculator_param.mode, self.encrypted_mode_calculator_param.re_encrypted_rate) for _ in range(batch_num) ] LOGGER.info("Start initialize model.") LOGGER.info("fit_intercept:{}".format( self.init_param_obj.fit_intercept)) model_shape = self.get_features_shape(data_instances) weight = self.initializer.init_model(model_shape, init_params=self.init_param_obj) if self.init_param_obj.fit_intercept is True: self.coef_ = weight[:-1] self.intercept_ = weight[-1] else: self.coef_ = weight is_send_all_batch_index = False self.n_iter_ = 0 index_data_inst_map = {} while self.n_iter_ < self.max_iter: LOGGER.info("iter:{}".format(self.n_iter_)) # each iter will get the same batch_data_generator batch_data_generator = mini_batch_obj.mini_batch_data_generator( result='index') batch_index = 0 for batch_data_index in batch_data_generator: LOGGER.info("batch:{}".format(batch_index)) if not is_send_all_batch_index: LOGGER.info("remote mini-batch index to Host") federation.remote( batch_data_index, name=self.transfer_variable.batch_data_index.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.batch_data_index, self.n_iter_, batch_index), role=consts.HOST, idx=0) if batch_index >= mini_batch_obj.batch_nums - 1: is_send_all_batch_index = True # Get mini-batch train data if len(index_data_inst_map) < batch_num: batch_data_inst = data_instances.join( batch_data_index, lambda data_inst, index: data_inst) index_data_inst_map[batch_index] = batch_data_inst else: batch_data_inst = index_data_inst_map[batch_index] # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst' batch_feat_inst = self.transform(batch_data_inst) # guest/host forward self.compute_forward(batch_feat_inst, self.coef_, self.intercept_, batch_index) host_forward = federation.get( name=self.transfer_variable.host_forward_dict.name, tag=self.transfer_variable.generate_transferid( 
self.transfer_variable.host_forward_dict, self.n_iter_, batch_index), idx=0) LOGGER.info("Get host_forward from host") aggregate_forward_res = self.aggregate_forward(host_forward) en_aggregate_wx = aggregate_forward_res.mapValues( lambda v: v[0]) en_aggregate_wx_square = aggregate_forward_res.mapValues( lambda v: v[1]) # compute [[d]] if self.gradient_operator is None: self.gradient_operator = HeteroLogisticGradient( self.encrypt_operator) fore_gradient = self.gradient_operator.compute_fore_gradient( batch_feat_inst, en_aggregate_wx) federation.remote( fore_gradient, name=self.transfer_variable.fore_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.fore_gradient, self.n_iter_, batch_index), role=consts.HOST, idx=0) LOGGER.info("Remote fore_gradient to Host") # compute guest gradient and loss guest_gradient, loss = self.gradient_operator.compute_gradient_and_loss( batch_feat_inst, fore_gradient, en_aggregate_wx, en_aggregate_wx_square, self.fit_intercept) # loss regulation if necessary if self.updater is not None: guest_loss_regular = self.updater.loss_norm(self.coef_) loss += self.encrypt_operator.encrypt(guest_loss_regular) federation.remote( guest_gradient, name=self.transfer_variable.guest_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.guest_gradient, self.n_iter_, batch_index), role=consts.ARBITER, idx=0) LOGGER.info("Remote guest_gradient to arbiter") optim_guest_gradient = federation.get( name=self.transfer_variable.guest_optim_gradient.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.guest_optim_gradient, self.n_iter_, batch_index), idx=0) LOGGER.info("Get optim_guest_gradient from arbiter") # update model LOGGER.info("update_model") self.update_model(optim_guest_gradient) # update local model that transforms features of raw input 'batch_data_inst' training_info = { "iteration": self.n_iter_, "batch_index": batch_index } self.update_local_model(fore_gradient, batch_data_inst, self.coef_, **training_info) # Get loss regulation from Host if regulation is set if self.updater is not None: en_host_loss_regular = federation.get( name=self.transfer_variable.host_loss_regular.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.host_loss_regular, self.n_iter_, batch_index), idx=0) LOGGER.info("Get host_loss_regular from Host") loss += en_host_loss_regular federation.remote( loss, name=self.transfer_variable.loss.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.loss, self.n_iter_, batch_index), role=consts.ARBITER, idx=0) LOGGER.info("Remote loss to arbiter") # is converge of loss in arbiter batch_index += 1 # temporary resource recovery and will be removed in the future rubbish_list = [ host_forward, aggregate_forward_res, en_aggregate_wx, en_aggregate_wx_square, fore_gradient, self.guest_forward ] rubbish_clear(rubbish_list) is_stopped = federation.get( name=self.transfer_variable.is_stopped.name, tag=self.transfer_variable.generate_transferid( self.transfer_variable.is_stopped, self.n_iter_, batch_index), idx=0) LOGGER.info("Get is_stop flag from arbiter:{}".format(is_stopped)) self.n_iter_ += 1 if is_stopped: LOGGER.info( "Get stop signal from arbiter, model is converged, iter:{}" .format(self.n_iter_)) break LOGGER.info("Reach max iter {}, train model finish!".format( self.max_iter))
def fit(self, data_instances, validate_data=None): LOGGER.debug("Start data count: {}".format(data_instances.count())) self._abnormal_detection(data_instances) self.init_schema(data_instances) validation_strategy = self.init_validation_strategy(data_instances, validate_data) pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',)) if self.use_encrypt: self.cipher_operator.set_public_key(pubkey) self.model_weights = self._init_model_variables(data_instances) w = self.cipher_operator.encrypt_list(self.model_weights.unboxed) self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept) LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed)) mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) total_batch_num = mini_batch_obj.batch_nums if self.use_encrypt: re_encrypt_times = total_batch_num // self.re_encrypt_batches + 1 LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format( re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches)) self.cipher.set_re_cipher_time(re_encrypt_times) total_data_num = data_instances.count() LOGGER.debug("Current data count: {}".format(total_data_num)) model_weights = self.model_weights degree = 0 self.__synchronize_encryption() self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get(idx=0, suffix=('train',)) LOGGER.debug("party num:" + str(self.zcl_num_party)) self.__init_model() self.train_loss_results = [] self.train_accuracy_results = [] self.test_loss_results = [] self.test_accuracy_results = [] for iter_num in range(self.max_iter): # mini-batch LOGGER.debug("In iter: {}".format(iter_num)) # batch_data_generator = self.mini_batch_obj.mini_batch_data_generator() batch_num = 0 total_loss = 0 epoch_train_loss_avg = tfe.metrics.Mean() epoch_train_accuracy = tfe.metrics.Accuracy() for train_x, train_y in self.zcl_dataset: LOGGER.info("Staring batch {}".format(batch_num)) start_t = time.time() loss_value, grads = self.__grad(self.zcl_model, train_x, train_y) loss_value = loss_value.numpy() grads = [x.numpy() for x in grads] LOGGER.info("Start encrypting") loss_value = batch_encryption.encrypt(self.zcl_encrypt_operator.get_public_key(), loss_value) grads = [batch_encryption.encrypt_matrix(self.zcl_encrypt_operator.get_public_key(), x) for x in grads] LOGGER.info("Finish encrypting") grads = Gradients(grads) self.transfer_variable.host_grad.remote(obj=grads.for_remote(), role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num)) LOGGER.info("Sent grads") self.transfer_variable.host_loss.remote(obj=loss_value, role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num)) LOGGER.info("Sent loss") sum_grads = self.transfer_variable.aggregated_grad.get(idx=0, suffix=(iter_num, batch_num)) LOGGER.info("Got grads") sum_loss = self.transfer_variable.aggregated_loss.get(idx=0, suffix=(iter_num, batch_num)) LOGGER.info("Got loss") sum_loss = batch_encryption.decrypt(self.zcl_encrypt_operator.get_privacy_key(), sum_loss) sum_grads = [ batch_encryption.decrypt_matrix(self.zcl_encrypt_operator.get_privacy_key(), x).astype(np.float32) for x in sum_grads.unboxed] LOGGER.info("Finish decrypting") # sum_grads = np.array(sum_grads) / self.zcl_num_party self.zcl_optimizer.apply_gradients(zip(sum_grads, self.zcl_model.trainable_variables), self.zcl_global_step) elapsed_time = time.time() - start_t # epoch_train_loss_avg(loss_value) # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, 
output_type=tf.int32), # train_y) self.train_loss_results.append(sum_loss) train_accuracy_v = accuracy_score(train_y, tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32)) self.train_accuracy_results.append(train_accuracy_v) test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test, self.zcl_y_test) self.test_loss_results.append(test_loss_v) test_accuracy_v = accuracy_score(self.zcl_y_test, tf.argmax(self.zcl_model(self.zcl_x_test), axis=1, output_type=tf.int32)) self.test_accuracy_results.append(test_accuracy_v) LOGGER.info( "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, " "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format( iter_num, batch_num, sum_loss, train_accuracy_v, test_loss_v, test_accuracy_v, elapsed_time) ) batch_num += 1 if batch_num >= self.zcl_early_stop_batch: return self.n_iter_ = iter_num
def fit(self, data_instances):
    LOGGER.info("parameters: alpha: {}, eps: {}, max_iter: {}"
                "batch_size: {}".format(self.alpha, self.eps, self.max_iter, self.batch_size))
    self.__init_parameters()
    w = self.__init_model(data_instances)

    mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

    for iter_num in range(self.max_iter):
        # mini-batch
        # LOGGER.debug("Enter iter_num: {}".format(iter_num))
        batch_data_generator = mini_batch_obj.mini_batch_data_generator()
        total_loss = 0
        batch_num = 0
        for batch_data in batch_data_generator:
            f = functools.partial(self.gradient_operator.compute,
                                  coef=self.coef_,
                                  intercept=self.intercept_,
                                  fit_intercept=self.fit_intercept)
            grad_loss = batch_data.mapPartitions(f)
            n = grad_loss.count()
            grad, loss = grad_loss.reduce(self.aggregator.aggregate_grad_loss)
            grad /= n
            loss /= n
            if self.updater is not None:
                loss_norm = self.updater.loss_norm(self.coef_)
                total_loss += (loss + loss_norm)
            # LOGGER.debug("before update: {}".format(grad))
            delta_grad = self.optimizer.apply_gradients(grad)
            # LOGGER.debug("after apply: {}".format(delta_grad))
            self.update_model(delta_grad)
            batch_num += 1

        total_loss /= batch_num
        w = self.merge_model()
        LOGGER.info("iter: {}, loss: {}".format(iter_num, total_loss))

        # send model
        model_transfer_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.guest_model, iter_num)
        federation.remote(w,
                          name=self.transfer_variable.guest_model.name,
                          tag=model_transfer_id,
                          role=consts.ARBITER,
                          idx=0)

        # send loss
        loss_transfer_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.guest_loss, iter_num)
        federation.remote(total_loss,
                          name=self.transfer_variable.guest_loss.name,
                          tag=loss_transfer_id,
                          role=consts.ARBITER,
                          idx=0)

        # recv model
        model_transfer_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.final_model, iter_num)
        w = federation.get(name=self.transfer_variable.final_model.name,
                           tag=model_transfer_id,
                           idx=0)
        w = np.array(w)
        # LOGGER.debug("Received final model: {}".format(w))
        self.set_coef_(w)

        # recv converge flag
        converge_flag_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.converge_flag, iter_num)
        converge_flag = federation.get(name=self.transfer_variable.converge_flag.name,
                                       tag=converge_flag_id,
                                       idx=0)

        self.n_iter_ = iter_num
        LOGGER.debug("converge flag is :{}".format(converge_flag))

        if converge_flag:
            # self.save_model(w)
            break
def fit(self, data_instances):
    self._abnormal_detection(data_instances)
    self.header = data_instances.schema.get('header')  # ['x1', 'x2', 'x3' ... ]
    self.__init_parameters()
    self.__init_model(data_instances)

    mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

    for iter_num in range(self.max_iter):
        # mini-batch
        batch_data_generator = mini_batch_obj.mini_batch_data_generator()
        total_loss = 0
        batch_num = 0
        for batch_data in batch_data_generator:
            n = batch_data.count()
            f = functools.partial(self.gradient_operator.compute,
                                  coef=self.coef_,
                                  intercept=self.intercept_,
                                  fit_intercept=self.fit_intercept)
            grad_loss = batch_data.mapPartitions(f)
            grad, loss = grad_loss.reduce(self.aggregator.aggregate_grad_loss)
            grad /= n
            loss /= n
            if self.updater is not None:
                loss_norm = self.updater.loss_norm(self.coef_)
                total_loss += (loss + loss_norm)
            delta_grad = self.optimizer.apply_gradients(grad)
            self.update_model(delta_grad)
            batch_num += 1

        total_loss /= batch_num
        w = self.merge_model()
        self.loss_history.append(total_loss)
        LOGGER.info("iter: {}, loss: {}".format(iter_num, total_loss))

        # send model
        model_transfer_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.guest_model, iter_num)
        federation.remote(w,
                          name=self.transfer_variable.guest_model.name,
                          tag=model_transfer_id,
                          role=consts.ARBITER,
                          idx=0)

        # send loss
        loss_transfer_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.guest_loss, iter_num)
        federation.remote(total_loss,
                          name=self.transfer_variable.guest_loss.name,
                          tag=loss_transfer_id,
                          role=consts.ARBITER,
                          idx=0)

        # recv model
        model_transfer_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.final_model, iter_num)
        w = federation.get(name=self.transfer_variable.final_model.name,
                           tag=model_transfer_id,
                           idx=0)
        w = np.array(w)
        self.set_coef_(w)

        # recv converge flag
        converge_flag_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.converge_flag, iter_num)
        converge_flag = federation.get(name=self.transfer_variable.converge_flag.name,
                                       tag=converge_flag_id,
                                       idx=0)

        self.n_iter_ = iter_num
        LOGGER.debug("converge flag is :{}".format(converge_flag))

        if converge_flag:
            self.is_converged = True
            break

    self.show_meta()
    self.show_model()
    LOGGER.debug("in fit self coef: {}".format(self.coef_))
    return data_instances
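# Illustrative sketch (NumPy, not the FATE arbiter code) of the aggregation step the guest
# above participates in each iteration: party models are combined and the merged model is
# sent back. Weighting by per-party sample counts ("degree") is an assumption carried over
# from the aggregator-based snippets in this collection.
import numpy as np

def weighted_average(models, degrees):
    """Average party weight vectors, weighting each by its sample count (degree)."""
    models = np.asarray(models, dtype=float)
    degrees = np.asarray(degrees, dtype=float)
    return (models * degrees[:, None]).sum(axis=0) / degrees.sum()

merged = weighted_average([[1.0, 2.0], [3.0, 4.0]], degrees=[100, 300])
assert np.allclose(merged, [2.5, 3.5])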
def fit(self, data_instances, validate_data=None): self._abnormal_detection(data_instances) self.init_schema(data_instances) validation_strategy = self.init_validation_strategy( data_instances, validate_data) self.model_weights = self._init_model_variables(data_instances) max_iter = self.max_iter total_data_num = data_instances.count() mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) model_weights = self.model_weights self.__synchronize_encryption() self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get( idx=0, suffix=('train', )) LOGGER.debug("party num:" + str(self.zcl_num_party)) self.__init_model() self.train_loss_results = [] self.train_accuracy_results = [] self.test_loss_results = [] self.test_accuracy_results = [] batch_num = 0 for iter_num in range(self.max_iter): total_loss = 0 # batch_num = 0 iter_num_ = 0 epoch_train_loss_avg = tfe.metrics.Mean() epoch_train_accuracy = tfe.metrics.Accuracy() for train_x, train_y in self.zcl_dataset: LOGGER.info("Staring batch {}".format(batch_num)) start_t = time.time() loss_value, grads = self.__grad(self.zcl_model, train_x, train_y) loss_value = loss_value.numpy() grads = [x.numpy() for x in grads] sizes = [layer.size * self.zcl_num_party for layer in grads] guest_max = [np.max(layer) for layer in grads] guest_min = [np.min(layer) for layer in grads] # clipping_thresholds_guest = batch_encryption.calculate_clip_threshold_aciq_l(grads, bit_width=self.bit_width) grad_max_all = self.transfer_variable.host_grad_max.get( idx=-1, suffix=(iter_num_, batch_num)) grad_min_all = self.transfer_variable.host_grad_min.get( idx=-1, suffix=(iter_num_, batch_num)) grad_max_all.append(guest_max) grad_min_all.append(guest_min) max_v = [] min_v = [] for layer_idx in range(len(grads)): max_v.append( [np.max([party[layer_idx] for party in grad_max_all])]) min_v.append( [np.min([party[layer_idx] for party in grad_min_all])]) grads_max_min = np.concatenate( [np.array(max_v), np.array(min_v)], axis=1) clipping_thresholds = batch_encryption.calculate_clip_threshold_aciq_g( grads_max_min, sizes, bit_width=self.bit_width) LOGGER.info("clipping threshold " + str(clipping_thresholds)) r_maxs = [x * self.zcl_num_party for x in clipping_thresholds] self.transfer_variable.clipping_threshold.remote( obj=clipping_thresholds, role=consts.HOST, idx=-1, suffix=(iter_num_, batch_num)) grads = batch_encryption.clip_with_threshold( grads, clipping_thresholds) LOGGER.info("Start batch encrypting") loss_value = batch_encryption.encrypt( self.zcl_encrypt_operator.get_public_key(), loss_value) # grads = [batch_encryption.encrypt_matrix(self.zcl_encrypt_operator.get_public_key(), x) for x in grads] enc_grads, og_shape = batch_encryption.batch_enc_per_layer( publickey=self.zcl_encrypt_operator.get_public_key(), party=grads, r_maxs=r_maxs, bit_width=self.bit_width, batch_size=self.e_batch_size) # grads = Gradients(enc_grads) LOGGER.info("Finish encrypting") # grads = self.encrypt_operator.get_public_key() # self.transfer_variable.guest_grad.remote(obj=grads.for_remote(), role=consts.ARBITER, idx=0, # suffix=(iter_num_, batch_num)) self.transfer_variable.guest_grad.remote(obj=enc_grads, role=consts.ARBITER, idx=0, suffix=(iter_num_, batch_num)) LOGGER.info("Sent grads") self.transfer_variable.guest_loss.remote(obj=loss_value, role=consts.ARBITER, idx=0, suffix=(iter_num_, batch_num)) LOGGER.info("Sent loss") sum_grads = self.transfer_variable.aggregated_grad.get( idx=0, suffix=(iter_num_, batch_num)) LOGGER.info("Got grads") sum_loss = 
self.transfer_variable.aggregated_loss.get( idx=0, suffix=(iter_num_, batch_num)) LOGGER.info("Got loss") sum_loss = batch_encryption.decrypt( self.zcl_encrypt_operator.get_privacy_key(), sum_loss) # sum_grads = [ # batch_encryption.decrypt_matrix(self.zcl_encrypt_operator.get_privacy_key(), x).astype(np.float32) for x # in sum_grads.unboxed] sum_grads = batch_encryption.batch_dec_per_layer( privatekey=self.zcl_encrypt_operator.get_privacy_key(), # party=sum_grads.unboxed, og_shapes=og_shape, party=sum_grads, og_shapes=og_shape, r_maxs=r_maxs, bit_width=self.bit_width, batch_size=self.e_batch_size) LOGGER.info("Finish decrypting") # sum_grads = np.array(sum_grads) / self.zcl_num_party self.zcl_optimizer.apply_gradients( zip(sum_grads, self.zcl_model.trainable_variables), self.zcl_global_step) elapsed_time = time.time() - start_t # epoch_train_loss_avg(loss_value) # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32), # train_y) self.train_loss_results.append(sum_loss) train_accuracy_v = accuracy_score( train_y, tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32)) self.train_accuracy_results.append(train_accuracy_v) test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test, self.zcl_y_test) self.test_loss_results.append(test_loss_v) test_accuracy_v = accuracy_score( self.zcl_y_test, tf.argmax(self.zcl_model(self.zcl_x_test), axis=1, output_type=tf.int32)) self.test_accuracy_results.append(test_accuracy_v) LOGGER.info( "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, " "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format( iter_num, batch_num, sum_loss, train_accuracy_v, test_loss_v, test_accuracy_v, elapsed_time)) batch_num += 1 # if batch_num >= self.zcl_early_stop_batch: # return self.n_iter_ = iter_num
class Guest(batch_info_sync.Guest):
    def __init__(self):
        self.mini_batch_obj = None
        self.finish_sycn = False
        self.batch_nums = None
        self.batch_masked = False

    def register_batch_generator(self, transfer_variables, has_arbiter=True):
        self._register_batch_data_index_transfer(transfer_variables.batch_info,
                                                 transfer_variables.batch_data_index,
                                                 getattr(transfer_variables, "batch_validate_info", None),
                                                 has_arbiter)

    def initialize_batch_generator(self, data_instances, batch_size, suffix=tuple(), shuffle=False,
                                   batch_strategy="full", masked_rate=0):
        self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size, shuffle=shuffle,
                                        batch_strategy=batch_strategy, masked_rate=masked_rate)
        self.batch_nums = self.mini_batch_obj.batch_nums
        self.batch_masked = self.mini_batch_obj.batch_size != self.mini_batch_obj.masked_batch_size
        batch_info = {"batch_size": self.mini_batch_obj.batch_size,
                      "batch_num": self.batch_nums,
                      "batch_mutable": self.mini_batch_obj.batch_mutable,
                      "masked_batch_size": self.mini_batch_obj.masked_batch_size}
        self.sync_batch_info(batch_info, suffix)

        if not self.mini_batch_obj.batch_mutable:
            self.prepare_batch_data(suffix)

    def prepare_batch_data(self, suffix=tuple()):
        self.mini_batch_obj.generate_batch_data()
        index_generator = self.mini_batch_obj.mini_batch_data_generator(result='index')
        batch_index = 0
        for batch_data_index in index_generator:
            batch_suffix = suffix + (batch_index,)
            self.sync_batch_index(batch_data_index, batch_suffix)
            batch_index += 1

    def generate_batch_data(self, with_index=False, suffix=tuple()):
        if self.mini_batch_obj.batch_mutable:
            self.prepare_batch_data(suffix)

        if with_index:
            data_generator = self.mini_batch_obj.mini_batch_data_generator(result='both')
            for batch_data, index_data in data_generator:
                yield batch_data, index_data
        else:
            data_generator = self.mini_batch_obj.mini_batch_data_generator(result='data')
            for batch_data in data_generator:
                yield batch_data

    def verify_batch_legality(self, suffix=tuple()):
        validate_infos = self.sync_batch_validate_info(suffix)
        least_batch_size = 0
        is_legal = True
        for validate_info in validate_infos:
            legality = validate_info.get("legality")
            if not legality:
                is_legal = False
                least_batch_size = max(least_batch_size, validate_info.get("least_batch_size"))

        if not is_legal:
            raise ValueError(f"To use batch masked strategy, "
                             f"(masked_rate + 1) * batch_size should > {least_batch_size}")
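# Hedged sketch (plain Python, not the FATE API) of the epoch loop implied by the class above.
# When batches are "mutable" (for example, when shuffle is enabled), the index split would be
# re-drawn and re-synced before every epoch; otherwise it is prepared once up front. Names and
# the shuffle semantics here are assumptions for illustration only.
import random

def draw_index_batches(n_rows, batch_size, shuffle, seed=None):
    """Re-draw the index split; with shuffle=True each call yields a different split."""
    idx = list(range(n_rows))
    if shuffle:
        random.Random(seed).shuffle(idx)
    return [idx[i:i + batch_size] for i in range(0, n_rows, batch_size)]

batch_mutable = True          # stands in for mini_batch_obj.batch_mutable
static_batches = None if batch_mutable else draw_index_batches(10, 4, shuffle=False)
for epoch in range(2):
    batches = draw_index_batches(10, 4, shuffle=True, seed=epoch) if batch_mutable else static_batches
    for batch_index, batch in enumerate(batches):
        pass  # in FATE, sync_batch_index(batch, suffix + (batch_index,)) would happen here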
def fit(self, data_instances, validate_data=None): self._abnormal_detection(data_instances) self.init_schema(data_instances) validation_strategy = self.init_validation_strategy( data_instances, validate_data) self.model_weights = self._init_model_variables(data_instances) max_iter = self.max_iter # total_data_num = data_instances.count() mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) model_weights = self.model_weights degree = 0 while self.n_iter_ < max_iter: batch_data_generator = mini_batch_obj.mini_batch_data_generator() self.optimizer.set_iters(self.n_iter_) if self.n_iter_ > 0 and self.n_iter_ % self.aggregate_iters == 0: weight = self.aggregator.aggregate_then_get( model_weights, degree=degree, suffix=self.n_iter_) LOGGER.debug( "Before aggregate: {}, degree: {} after aggregated: {}". format(model_weights.unboxed / degree, degree, weight.unboxed)) self.model_weights = LogisticRegressionWeights( weight.unboxed, self.fit_intercept) loss = self._compute_loss(data_instances) self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_, )) degree = 0 self.is_converged = self.aggregator.get_converge_status( suffix=(self.n_iter_, )) LOGGER.info( "n_iters: {}, loss: {} converge flag is :{}".format( self.n_iter_, loss, self.is_converged)) if self.is_converged: break model_weights = self.model_weights batch_num = 0 for batch_data in batch_data_generator: n = batch_data.count() LOGGER.debug( "In each batch, lr_weight: {}, batch_data count: {}". format(model_weights.unboxed, n)) f = functools.partial(self.gradient_operator.compute_gradient, coef=model_weights.coef_, intercept=model_weights.intercept_, fit_intercept=self.fit_intercept) grad = batch_data.mapPartitions(f).reduce( fate_operator.reduce_add) grad /= n LOGGER.debug( 'iter: {}, batch_index: {}, grad: {}, n: {}'.format( self.n_iter_, batch_num, grad, n)) model_weights = self.optimizer.update_model(model_weights, grad, has_applied=False) batch_num += 1 degree += n validation_strategy.validate(self, self.n_iter_) self.n_iter_ += 1
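# --- Local illustration (assumption): compute_gradient above is taken to be the
# standard averaged logistic-regression gradient that the mapPartitions/reduce pair
# accumulates before the `grad /= n` step. This stand-alone NumPy version only
# mirrors that arithmetic; it is not the federated operator itself.
import numpy as np

def logistic_gradient(X, y, coef, intercept=0.0, fit_intercept=True):
    z = X @ coef + (intercept if fit_intercept else 0.0)
    residual = 1.0 / (1.0 + np.exp(-z)) - y         # sigmoid(Xw + b) - y
    grad_coef = X.T @ residual / len(y)             # averaged over the batch, like grad /= n
    if fit_intercept:
        return np.append(grad_coef, residual.mean())    # last entry: intercept gradient
    return grad_coef

X = np.random.rand(8, 3)
y = np.random.randint(0, 2, size=8).astype(float)
print(logistic_gradient(X, y, coef=np.zeros(3)))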
def fit(self, data_instances): if not self.need_run: return data_instances self._abnormal_detection(data_instances) self.init_schema(data_instances) self.__init_parameters() self.__init_model(data_instances) mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) for iter_num in range(self.max_iter): # mini-batch batch_data_generator = mini_batch_obj.mini_batch_data_generator() total_loss = 0 batch_num = 0 for batch_data in batch_data_generator: n = batch_data.count() f = functools.partial(self.gradient_operator.compute, coef=self.coef_, intercept=self.intercept_, fit_intercept=self.fit_intercept) grad_loss = batch_data.mapPartitions(f) grad, loss = grad_loss.reduce( self.aggregator.aggregate_grad_loss) grad /= n loss /= n if self.updater is not None: loss_norm = self.updater.loss_norm(self.coef_) total_loss += (loss + loss_norm) delta_grad = self.optimizer.apply_gradients(grad) self.update_model(delta_grad) batch_num += 1 total_loss /= batch_num # if not self.use_loss: # total_loss = np.linalg.norm(self.coef_) w = self.merge_model() if not self.need_one_vs_rest: metric_meta = MetricMeta(name='train', metric_type="LOSS", extra_metas={ "unit_name": "iters", }) # metric_name = self.get_metric_name('loss') self.callback_meta(metric_name='loss', metric_namespace='train', metric_meta=metric_meta) self.callback_metric( metric_name='loss', metric_namespace='train', metric_data=[Metric(iter_num, total_loss)]) self.loss_history.append(total_loss) LOGGER.info("iter: {}, loss: {}".format(iter_num, total_loss)) # send model model_transfer_id = self.transfer_variable.generate_transferid( self.transfer_variable.guest_model, iter_num) LOGGER.debug("Start to remote model: {}, transfer_id: {}".format( w, model_transfer_id)) federation.remote(w, name=self.transfer_variable.guest_model.name, tag=model_transfer_id, role=consts.ARBITER, idx=0) # send loss # if self.use_loss: loss_transfer_id = self.transfer_variable.generate_transferid( self.transfer_variable.guest_loss, iter_num) LOGGER.debug( "Start to remote total_loss: {}, transfer_id: {}".format( total_loss, loss_transfer_id)) federation.remote(total_loss, name=self.transfer_variable.guest_loss.name, tag=loss_transfer_id, role=consts.ARBITER, idx=0) # recv model model_transfer_id = self.transfer_variable.generate_transferid( self.transfer_variable.final_model, iter_num) w = federation.get(name=self.transfer_variable.final_model.name, tag=model_transfer_id, idx=0) w = np.array(w) self.set_coef_(w) # recv converge flag converge_flag_id = self.transfer_variable.generate_transferid( self.transfer_variable.converge_flag, iter_num) converge_flag = federation.get( name=self.transfer_variable.converge_flag.name, tag=converge_flag_id, idx=0) self.n_iter_ = iter_num LOGGER.debug("converge flag is :{}".format(converge_flag)) if converge_flag: self.is_converged = True break
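# --- Sketch of the per-epoch loss bookkeeping above (assumption: updater.loss_norm
# is an L2-style penalty on the coefficients; the names here are illustrative and
# not the library API).
import numpy as np

def epoch_loss(batch_losses, coef, alpha=0.01):
    penalty = alpha * float(np.dot(coef, coef)) / 2     # stand-in for updater.loss_norm(coef)
    return sum(loss + penalty for loss in batch_losses) / len(batch_losses)

print(epoch_loss([0.69, 0.62, 0.55], coef=np.array([0.1, -0.2])))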
def fit(self, data_instances, validate_data=None): LOGGER.debug("Start data count: {}".format(data_instances.count())) self._abnormal_detection(data_instances) self.init_schema(data_instances) # validation_strategy = self.init_validation_strategy(data_instances, validate_data) pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit', )) if self.use_encrypt: self.cipher_operator.set_public_key(pubkey) self.model_weights = self._init_model_variables(data_instances) w = self.cipher_operator.encrypt_list(self.model_weights.unboxed) self.model_weights = LogisticRegressionWeights( w, self.model_weights.fit_intercept) LOGGER.debug("After init, model_weights: {}".format( self.model_weights.unboxed)) mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size) total_batch_num = mini_batch_obj.batch_nums if self.use_encrypt: re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1 LOGGER.debug( "re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}" .format(re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches)) self.cipher.set_re_cipher_time(re_encrypt_times) total_data_num = data_instances.count() LOGGER.debug("Current data count: {}".format(total_data_num)) model_weights = self.model_weights degree = 0 while self.n_iter_ < self.max_iter + 1: batch_data_generator = mini_batch_obj.mini_batch_data_generator() if ((self.n_iter_ + 1) % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter: weight = self.aggregator.aggregate_then_get( model_weights, degree=degree, suffix=self.n_iter_) # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format( # model_weights.unboxed / degree, # degree, # weight.unboxed)) self.model_weights = LogisticRegressionWeights( weight.unboxed, self.fit_intercept) if not self.use_encrypt: loss = self._compute_loss(data_instances) self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_, )) LOGGER.info("n_iters: {}, loss: {}".format( self.n_iter_, loss)) degree = 0 self.is_converged = self.aggregator.get_converge_status( suffix=(self.n_iter_, )) LOGGER.info("n_iters: {}, is_converge: {}".format( self.n_iter_, self.is_converged)) if self.is_converged or self.n_iter_ == self.max_iter: break model_weights = self.model_weights batch_num = 0 for batch_data in batch_data_generator: n = batch_data.count() degree += n LOGGER.debug('before compute_gradient') f = functools.partial(self.gradient_operator.compute_gradient, coef=model_weights.coef_, intercept=model_weights.intercept_, fit_intercept=self.fit_intercept) grad = batch_data.mapPartitions(f).reduce( fate_operator.reduce_add) grad /= n model_weights = self.optimizer.update_model(model_weights, grad, has_applied=False) if self.use_encrypt and batch_num % self.re_encrypt_batches == 0: LOGGER.debug( "Before accept re_encrypted_model, batch_iter_num: {}". format(batch_num)) w = self.cipher.re_cipher(w=model_weights.unboxed, iter_num=self.n_iter_, batch_iter_num=batch_num) model_weights = LogisticRegressionWeights( w, self.fit_intercept) batch_num += 1 # validation_strategy.validate(self, self.n_iter_) self.n_iter_ += 1 LOGGER.info("Finish Training task, total iters: {}".format( self.n_iter_))
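# --- Arithmetic check (illustrative function name): the host-side cadence above,
# re_encrypt_times = (total_batch_num - 1) // re_encrypt_batches + 1, is a ceiling
# division, i.e. at most ceil(total_batch_num / re_encrypt_batches) re-encryption
# rounds are requested from the arbiter per epoch.
import math

def re_encrypt_times(total_batch_num, re_encrypt_batches):
    return (total_batch_num - 1) // re_encrypt_batches + 1

assert re_encrypt_times(10, 3) == math.ceil(10 / 3) == 4
assert re_encrypt_times(9, 3) == math.ceil(9 / 3) == 3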
def fit(self, data_instances, node2id, local_instances=None, common_nodes=None):
    """
    Train the network-embedding (NE) model on the host side.

    Parameters
    ----------
    data_instances: DTable of anchor nodes, the input data
    """
    LOGGER.info("Enter hetero_ne host")
    self.n_node = len(node2id)
    LOGGER.info("Host party has {} nodes".format(self.n_node))

    data_instances = data_instances.mapValues(HeteroNEHost.load_data)
    LOGGER.info("Transform input data to train instances")

    public_key = federation.get(
        name=self.transfer_variable.paillier_pubkey.name,
        tag=self.transfer_variable.generate_transferid(
            self.transfer_variable.paillier_pubkey),
        idx=0)
    LOGGER.info("Get public key from arbiter: {}".format(public_key))
    self.encrypt_operator.set_public_key(public_key)

    ##############
    # horizontal federated learning
    LOGGER.info("Generate mini-batch for local instances in host")
    mini_batch_obj_local = MiniBatch(local_instances, batch_size=self.batch_size)
    common_node_instances = eggroll.parallelize(
        ((node, node) for node in common_nodes),
        include_key=True,
        name='common_nodes')
    ##############

    batch_info = federation.get(
        name=self.transfer_variable.batch_info.name,
        tag=self.transfer_variable.generate_transferid(
            self.transfer_variable.batch_info),
        idx=0)
    LOGGER.info("Get batch_info from guest: {}".format(batch_info))
    self.batch_size = batch_info['batch_size']
    self.batch_num = batch_info['batch_num']
    if self.batch_size < consts.MIN_BATCH_SIZE and self.batch_size != -1:
        raise ValueError(
            "Batch size received from guest should be no less than {} unless it is -1; "
            "got batch_size {}".format(consts.MIN_BATCH_SIZE, self.batch_size))

    self.encrypted_calculator = [
        EncryptModeCalculator(
            self.encrypt_operator,
            self.encrypted_mode_calculator_param.mode,
            self.encrypted_mode_calculator_param.re_encrypted_rate)
        for _ in range(self.batch_num)
    ]

    LOGGER.info("Start to initialize model.")
    self.embedding_ = self.initializer.init_model((self.n_node, self.dim),
                                                  self.init_param_obj)

    self.n_iter_ = 0
    index_data_inst_map = {}

    while self.n_iter_ < self.max_iter:
        LOGGER.info("iter: {}".format(self.n_iter_))

        #################
        local_batch_data_generator = mini_batch_obj_local.mini_batch_data_generator()
        total_loss = 0
        local_batch_num = 0
        LOGGER.info("Horizontal learning on local instances")
        for local_batch_data in local_batch_data_generator:
            n = local_batch_data.count()
            LOGGER.info("Local batch data count: {}".format(n))
            E_Y = self.compute_local_embedding(local_batch_data, self.embedding_, node2id)
            local_grads_e1, local_grads_e2, local_loss = self.local_gradient_operator.compute(
                E_Y, 'E_1')
            local_grads_e1 = local_grads_e1.mapValues(
                lambda g: self.local_optimizer.apply_gradients(g / n))
            local_grads_e2 = local_grads_e2.mapValues(
                lambda g: self.local_optimizer.apply_gradients(g / n))
            e1id_join_grads = local_batch_data.join(
                local_grads_e1, lambda v, g: (node2id[v[0]], g))
            e2id_join_grads = local_batch_data.join(
                local_grads_e2, lambda v, g: (node2id[v[1]], g))
            self.update_model(e1id_join_grads)
            self.update_model(e2id_join_grads)

            local_loss = local_loss / n
            local_batch_num += 1
            total_loss += local_loss
            LOGGER.info("gradient count: {}".format(e1id_join_grads.count()))

        host_common_embedding = common_node_instances.mapValues(
            lambda node: self.embedding_[node2id[node]])
        federation.remote(
            host_common_embedding,
            name=self.transfer_variable.host_common_embedding.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.host_common_embedding, self.n_iter_, 0),
            role=consts.ARBITER,
            idx=0)
        common_embedding = federation.get(
            name=self.transfer_variable.common_embedding.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.common_embedding, self.n_iter_, 0),
            idx=0)
        self.update_common_nodes(common_embedding, common_nodes, node2id)

        total_loss /= local_batch_num
        LOGGER.info("Iter {}, local loss: {}".format(self.n_iter_, total_loss))

        batch_index = 0
        while batch_index < self.batch_num:
            LOGGER.info("batch: {}".format(batch_index))
            # Cache the batch index tables so later iterations reuse the same
            # batch order without communicating with the guest again.
            if len(self.batch_index_list) < self.batch_num:
                batch_data_index = federation.get(
                    name=self.transfer_variable.batch_data_index.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.batch_data_index, self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get batch_index from guest")
                self.batch_index_list.append(batch_data_index)
            else:
                batch_data_index = self.batch_index_list[batch_index]

            # Cache the mini-batch train data to avoid re-joining in later iterations.
            if len(index_data_inst_map) < self.batch_num:
                batch_data_inst = batch_data_index.join(data_instances, lambda g, d: d)
                index_data_inst_map[batch_index] = batch_data_inst
            else:
                batch_data_inst = index_data_inst_map[batch_index]

            LOGGER.info("batch_data_inst size: {}".format(batch_data_inst.count()))
            # self.transform(data_inst)

            # compute forward
            host_forward = self.compute_forward(batch_data_inst, self.embedding_,
                                                node2id, batch_index)
            federation.remote(
                host_forward,
                name=self.transfer_variable.host_forward_dict.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.host_forward_dict, self.n_iter_, batch_index),
                role=consts.GUEST,
                idx=0)
            LOGGER.info("Remote host_forward to guest")

            # Get optimized host gradient and update model
            optim_host_gradient = federation.get(
                name=self.transfer_variable.host_optim_gradient.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.host_optim_gradient, self.n_iter_, batch_index),
                idx=0)
            LOGGER.info("Get optim_host_gradient from arbiter")

            nodeid_join_gradient = batch_data_inst.join(
                optim_host_gradient,
                lambda instance, gradient: (node2id[instance.features], gradient))
            LOGGER.info("update_model")
            self.update_model(nodeid_join_gradient)

            # update local model
            # training_info = {"iteration": self.n_iter_, "batch_index": batch_index}
            # self.update_local_model(fore_gradient, batch_data_inst, self.coef_, **training_info)

            batch_index += 1

            rubbish_list = [host_forward]
            rubbish_clear(rubbish_list)

        #######
        host_common_embedding = common_node_instances.mapValues(
            lambda node: self.embedding_[node2id[node]])
        federation.remote(
            host_common_embedding,
            name=self.transfer_variable.host_common_embedding.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.host_common_embedding, self.n_iter_, 1),
            role=consts.ARBITER,
            idx=0)
        common_embedding = federation.get(
            name=self.transfer_variable.common_embedding.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.common_embedding, self.n_iter_, 1),
            idx=0)
        self.update_common_nodes(common_embedding, common_nodes, node2id)
        #######

        is_stopped = federation.get(
            name=self.transfer_variable.is_stopped.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.is_stopped, self.n_iter_),
            idx=0)
        LOGGER.info("Get is_stopped flag from arbiter: {}".format(is_stopped))

        self.n_iter_ += 1
        if is_stopped:
            break

    LOGGER.info("Training loop finished, total iters: {}".format(self.n_iter_))
    embedding_table = eggroll.table(name='host', namespace='node_embedding', partition=10)
    id2node = dict(zip(node2id.values(), node2id.keys()))
    for id, embedding in enumerate(self.embedding_):
        embedding_table.put(id2node[id], embedding)
    embedding_table.save_as(name='host', namespace='node_embedding', partition=10)
    LOGGER.info("Node embeddings saved, train model finished.")