Example #1
class Guest(batch_info_sync.Guest):
    def __init__(self):
        self.mini_batch_obj = None
        self.finish_sycn = False
        self.batch_nums = None

    def register_batch_generator(self, transfer_variables, has_arbiter=True):
        self._register_batch_data_index_transfer(
            transfer_variables.batch_info, transfer_variables.batch_data_index,
            has_arbiter)

    def initialize_batch_generator(self,
                                   data_instances,
                                   batch_size,
                                   suffix=tuple()):
        self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size)
        self.batch_nums = self.mini_batch_obj.batch_nums
        batch_info = {"batch_size": batch_size, "batch_num": self.batch_nums}
        self.sync_batch_info(batch_info, suffix)
        index_generator = self.mini_batch_obj.mini_batch_data_generator(
            result='index')
        batch_index = 0
        for batch_data_index in index_generator:
            batch_suffix = suffix + (batch_index, )
            self.sync_batch_index(batch_data_index, batch_suffix)
            batch_index += 1

    def generate_batch_data(self):
        data_generator = self.mini_batch_obj.mini_batch_data_generator(
            result='data')
        for batch_data in data_generator:
            yield batch_data
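
A minimal driver sketch for this Guest class, assuming the `transfer_variables` object, the `data_instances` DTable and the batch size of 320 come from the surrounding FATE job context (they are not part of the example above):

guest = Guest()
guest.register_batch_generator(transfer_variables, has_arbiter=True)
guest.initialize_batch_generator(data_instances, batch_size=320, suffix=("train",))
for batch_data in guest.generate_batch_data():
    # run one optimization step on this mini-batch
    pass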
Example #2
    def test_mini_batch_data_generator(self, data_num=100, batch_size=320):
        t0 = time.time()
        feature_num = 20

        expect_batches = data_num // batch_size
        # print("expect_batches: {}".format(expect_batches))
        data_instances = self.prepare_data(data_num=data_num, feature_num=feature_num)
        # print("Prepare data time: {}".format(time.time() - t0))
        mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=batch_size)
        batch_data_generator = mini_batch_obj.mini_batch_data_generator()
        batch_id = 0
        pre_time = time.time() - t0
        # print("Prepare mini batch time: {}".format(pre_time))
        total_num = 0
        for batch_data in batch_data_generator:
            batch_num = batch_data.count()
            if batch_id < expect_batches - 1:
                # print("In mini batch test, batch_num: {}, batch_size:{}".format(
                #     batch_num, batch_size
                # ))
                self.assertEqual(batch_num, batch_size)

            batch_id += 1
            total_num += batch_num
            # curt_time = time.time()
            # print("One batch time: {}".format(curt_time - pre_time))
            # pre_time = curt_time
        self.assertEqual(total_num, data_num)
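
The bare iteration pattern this test exercises can be reduced to the sketch below; `data_instances` is assumed to be a DTable built the same way as in `prepare_data` above:

mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=320)
total = 0
for batch_data in mini_batch_obj.mini_batch_data_generator():
    total += batch_data.count()
# every instance is visited exactly once across all batches
assert total == data_instances.count()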
Example #3
    def initialize_batch_generator(self, data_instances, batch_size, suffix=tuple(),
                                   shuffle=False, batch_strategy="full", masked_rate=0):
        self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size, shuffle=shuffle,
                                        batch_strategy=batch_strategy, masked_rate=masked_rate)
        self.batch_nums = self.mini_batch_obj.batch_nums
        self.batch_masked = self.mini_batch_obj.batch_size != self.mini_batch_obj.masked_batch_size
        batch_info = {"batch_size": self.mini_batch_obj.batch_size, "batch_num": self.batch_nums,
                      "batch_mutable": self.mini_batch_obj.batch_mutable,
                      "masked_batch_size": self.mini_batch_obj.masked_batch_size}
        self.sync_batch_info(batch_info, suffix)

        if not self.mini_batch_obj.batch_mutable:
            self.prepare_batch_data(suffix)
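
A hedged call sketch for this extended signature; the guest object and `data_instances` come from the job context as in Example #1, and the argument values shown are illustrative only:

guest.initialize_batch_generator(data_instances, batch_size=320, suffix=("train",),
                                 shuffle=True, batch_strategy="full", masked_rate=0)
# when batch_mutable is False, batch data is prepared once up front
# via prepare_batch_data (see above)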
Example #4
    def initialize_batch_generator(self,
                                   data_instances,
                                   batch_size,
                                   suffix=tuple()):
        self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size)
        self.batch_nums = self.mini_batch_obj.batch_nums
        batch_info = {"batch_size": batch_size, "batch_num": self.batch_nums}
        self.sync_batch_info(batch_info, suffix)
        index_generator = self.mini_batch_obj.mini_batch_data_generator(
            result='index')
        batch_index = 0
        for batch_data_index in index_generator:
            batch_suffix = suffix + (batch_index, )
            self.sync_batch_index(batch_data_index, batch_suffix)
            batch_index += 1
Example #5
    def __init_parameters(self, data_instances):

        party_weight_id = self.transfer_variable.generate_transferid(
            self.transfer_variable.host_party_weight
        )
        # LOGGER.debug("party_weight_id: {}".format(party_weight_id))
        federation.remote(self.party_weight,
                          name=self.transfer_variable.host_party_weight.name,
                          tag=party_weight_id,
                          role=consts.ARBITER,
                          idx=0)

        self.__synchronize_encryption()

        # Send re-encrypt times
        self.mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)
        if self.use_encrypt:
            # LOGGER.debug("Use encryption, send re_encrypt_times")
            total_batch_num = self.mini_batch_obj.batch_nums
            re_encrypt_times = total_batch_num // self.re_encrypt_batches
            transfer_id = self.transfer_variable.generate_transferid(self.transfer_variable.re_encrypt_times)
            federation.remote(re_encrypt_times,
                              name=self.transfer_variable.re_encrypt_times.name,
                              tag=transfer_id,
                              role=consts.ARBITER,
                              idx=0)
            LOGGER.info("sent re_encrypt_times: {}".format(re_encrypt_times))
Example #6
    def fit(self, data_instances, validate_data=None):

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)

        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        self.model_weights = self._init_model_variables(data_instances)

        max_iter = self.max_iter
        total_data_num = data_instances.count()
        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        model_weights = self.model_weights

        self.__synchronize_encryption()
        self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get(
            idx=0, suffix=('train', ))
        LOGGER.debug("party num:" + str(self.zcl_num_party))
        self.__init_model()

        self.train_loss_results = []
        self.train_accuracy_results = []
        self.test_loss_results = []
        self.test_accuracy_results = []

        for iter_num in range(self.max_iter):
            total_loss = 0
            batch_num = 0
            epoch_train_loss_avg = tfe.metrics.Mean()
            epoch_train_accuracy = tfe.metrics.Accuracy()

            for train_x, train_y in self.zcl_dataset:
                LOGGER.info("Staring batch {}".format(batch_num))
                start_t = time.time()
                loss_value, grads = self.__grad(self.zcl_model, train_x,
                                                train_y)
                loss_value = loss_value.numpy()
                grads = [x.numpy() for x in grads]
                LOGGER.info("Start encrypting")
                loss_value = batch_encryption.encrypt(
                    self.zcl_encrypt_operator.get_public_key(), loss_value)
                grads = [
                    batch_encryption.encrypt_matrix(
                        self.zcl_encrypt_operator.get_public_key(), x)
                    for x in grads
                ]
                grads = Gradients(grads)
                LOGGER.info("Finish encrypting")
                # grads = self.encrypt_operator.get_public_key()
                self.transfer_variable.guest_grad.remote(
                    obj=grads.for_remote(),
                    role=consts.ARBITER,
                    idx=0,
                    suffix=(iter_num, batch_num))
                LOGGER.info("Sent grads")
                self.transfer_variable.guest_loss.remote(obj=loss_value,
                                                         role=consts.ARBITER,
                                                         idx=0,
                                                         suffix=(iter_num,
                                                                 batch_num))
                LOGGER.info("Sent loss")

                sum_grads = self.transfer_variable.aggregated_grad.get(
                    idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got grads")
                sum_loss = self.transfer_variable.aggregated_loss.get(
                    idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got loss")

                sum_loss = batch_encryption.decrypt(
                    self.zcl_encrypt_operator.get_privacy_key(), sum_loss)
                sum_grads = [
                    batch_encryption.decrypt_matrix(
                        self.zcl_encrypt_operator.get_privacy_key(),
                        x).astype(np.float32) for x in sum_grads.unboxed
                ]
                LOGGER.info("Finish decrypting")

                # sum_grads = np.array(sum_grads) / self.zcl_num_party

                self.zcl_optimizer.apply_gradients(
                    zip(sum_grads, self.zcl_model.trainable_variables),
                    self.zcl_global_step)

                elapsed_time = time.time() - start_t
                # epoch_train_loss_avg(loss_value)
                # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32),
                #                      train_y)
                self.train_loss_results.append(sum_loss)
                train_accuracy_v = accuracy_score(
                    train_y,
                    tf.argmax(self.zcl_model(train_x),
                              axis=1,
                              output_type=tf.int32))
                self.train_accuracy_results.append(train_accuracy_v)
                test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test,
                                          self.zcl_y_test)
                self.test_loss_results.append(test_loss_v)
                test_accuracy_v = accuracy_score(
                    self.zcl_y_test,
                    tf.argmax(self.zcl_model(self.zcl_x_test),
                              axis=1,
                              output_type=tf.int32))
                self.test_accuracy_results.append(test_accuracy_v)

                LOGGER.info(
                    "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, "
                    "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format(
                        iter_num, batch_num, sum_loss, train_accuracy_v,
                        test_loss_v, test_accuracy_v, elapsed_time))

                batch_num += 1

                if batch_num >= self.zcl_early_stop_batch:
                    return

            self.n_iter_ = iter_num
Example #7
    def fit(self, data_instances, validate_data=None):

        self._abnormal_detection(data_instances)
        self.check_abnormal_values(data_instances)
        self.init_schema(data_instances)

        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        self.model_weights = self._init_model_variables(data_instances)

        max_iter = self.max_iter
        # total_data_num = data_instances.count()
        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        model_weights = self.model_weights

        degree = 0
        self.prev_round_weights = copy.deepcopy(model_weights)

        while self.n_iter_ < max_iter + 1:
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            self.optimizer.set_iters(self.n_iter_)
            if ((self.n_iter_ + 1) % self.aggregate_iters
                    == 0) or self.n_iter_ == max_iter:
                weight = self.aggregator.aggregate_then_get(
                    model_weights, degree=degree, suffix=self.n_iter_)

                self.model_weights = LogisticRegressionWeights(
                    weight.unboxed, self.fit_intercept)

                # store prev_round_weights after aggregation
                self.prev_round_weights = copy.deepcopy(self.model_weights)
                # send loss to arbiter
                loss = self._compute_loss(data_instances,
                                          self.prev_round_weights)
                self.aggregator.send_loss(loss,
                                          degree=degree,
                                          suffix=(self.n_iter_, ))
                degree = 0

                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info(
                    "n_iters: {}, loss: {} converge flag is :{}".format(
                        self.n_iter_, loss, self.is_converged))
                if self.is_converged or self.n_iter_ == max_iter:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                # LOGGER.debug("In each batch, lr_weight: {}, batch_data count: {}".format(model_weights.unboxed, n))
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.applyPartitions(f).reduce(
                    fate_operator.reduce_add)
                grad /= n
                # LOGGER.debug('iter: {}, batch_index: {}, grad: {}, n: {}'.format(
                #     self.n_iter_, batch_num, grad, n))

                if self.use_proximal:  # use proximal term
                    model_weights = self.optimizer.update_model(
                        model_weights,
                        grad=grad,
                        has_applied=False,
                        prev_round_weights=self.prev_round_weights)
                else:
                    model_weights = self.optimizer.update_model(
                        model_weights, grad=grad, has_applied=False)

                batch_num += 1
                degree += n

            validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1
        self.set_summary(self.get_model_summary())
Example #8
    def fit(self,
            data_instances,
            node2id,
            local_instances=None,
            common_nodes=None):
        """
        Train node embedding for role guest
        Parameters
        ----------
        data_instances: DTable of target node and label, input data
        node2id: a dict which can map node name to id
        """
        LOGGER.info("samples number:{}".format(data_instances.count()))
        LOGGER.info("Enter network embedding procedure:")
        self.n_node = len(node2id)
        LOGGER.info("Bank A has {} nodes".format(self.n_node))

        data_instances = data_instances.mapValues(HeteroNEGuest.load_data)
        LOGGER.info("Transform input data to train instance")

        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)
        LOGGER.info("Get public_key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        # hetero network embedding
        LOGGER.info("Generate mini-batch from input data")
        mini_batch_obj = MiniBatch(data_instances, batch_size=self.batch_size)
        batch_num = mini_batch_obj.batch_nums

        LOGGER.info("samples number:{}".format(data_instances.count()))
        if self.batch_size == -1:
            LOGGER.info(
                "batch size is -1, set it to the number of data in data_instances"
            )
            self.batch_size = data_instances.count()

        ##############
        # horizontal federated learning
        LOGGER.info("Generate mini-batch for local instances in guest")
        mini_batch_obj_local = MiniBatch(local_instances,
                                         batch_size=self.batch_size)
        local_batch_num = mini_batch_obj_local.batch_nums
        common_node_instances = eggroll.parallelize(
            ((node, node) for node in common_nodes),
            include_key=True,
            name='common_nodes')
        ##############

        batch_info = {'batch_size': self.batch_size, "batch_num": batch_num}
        LOGGER.info("batch_info:{}".format(batch_info))
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.HOST,
                          idx=0)
        LOGGER.info("Remote batch_info to Host")

        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.ARBITER,
                          idx=0)
        LOGGER.info("Remote batch_info to Arbiter")

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.encrypt_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(batch_num)
        ]

        LOGGER.info("Start initialize model.")
        self.embedding_ = self.initializer.init_model((self.n_node, self.dim),
                                                      self.init_param_obj)
        LOGGER.info("Embedding shape={}".format(self.embedding_.shape))

        is_send_all_batch_index = False
        self.n_iter_ = 0
        index_data_inst_map = {}

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))

            #################
            local_batch_data_generator = mini_batch_obj_local.mini_batch_data_generator(
            )
            total_loss = 0
            local_batch_num = 0
            LOGGER.info("Enter the horizontally federated learning procedure:")
            for local_batch_data in local_batch_data_generator:
                n = local_batch_data.count()
                #LOGGER.info("Local batch data count:{}".format(n))
                E_Y = self.compute_local_embedding(local_batch_data,
                                                   self.embedding_, node2id)
                local_grads_e1, local_grads_e2, local_loss = self.local_gradient_operator.compute(
                    E_Y, 'E_1')
                local_grads_e1 = local_grads_e1.mapValues(
                    lambda g: self.local_optimizer.apply_gradients(g / n))
                local_grads_e2 = local_grads_e2.mapValues(
                    lambda g: self.local_optimizer.apply_gradients(g / n))
                e1id_join_grads = local_batch_data.join(
                    local_grads_e1, lambda v, g: (node2id[v[0]], g))
                e2id_join_grads = local_batch_data.join(
                    local_grads_e2, lambda v, g: (node2id[v[1]], g))
                self.update_model(e1id_join_grads)
                self.update_model(e2id_join_grads)

                local_loss = local_loss / n
                local_batch_num += 1
                total_loss += local_loss
                #LOGGER.info("gradient count:{}".format(e1id_join_grads.count()))

            guest_common_embedding = common_node_instances.mapValues(
                lambda node: self.embedding_[node2id[node]])
            federation.remote(
                guest_common_embedding,
                name=self.transfer_variable.guest_common_embedding.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.guest_common_embedding,
                    self.n_iter_, 0),
                role=consts.ARBITER,
                idx=0)
            LOGGER.info("Remote the embedding of common node to arbiter!")

            common_embedding = federation.get(
                name=self.transfer_variable.common_embedding.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.common_embedding, self.n_iter_, 0),
                idx=0)
            LOGGER.info(
                "Get the aggregated embedding of common node from arbiter!")

            self.update_common_nodes(common_embedding, common_nodes, node2id)

            total_loss /= local_batch_num
            LOGGER.info(
                "Iter {}, horizontally feaderated learning loss: {}".format(
                    self.n_iter_, total_loss))

            #################

            # vertically federated learning
            # each iter will get the same batch_data_generator
            LOGGER.info("Enter the vertically federated learning:")
            batch_data_generator = mini_batch_obj.mini_batch_data_generator(
                result='index')

            batch_index = 0
            for batch_data_index in batch_data_generator:
                LOGGER.info("batch:{}".format(batch_index))

                # only need to send the batch indices once
                if not is_send_all_batch_index:
                    LOGGER.info("remote mini-batch index to Host")
                    federation.remote(
                        batch_data_index,
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        role=consts.HOST,
                        idx=0)
                    if batch_index >= mini_batch_obj.batch_nums - 1:
                        is_send_all_batch_index = True

                # Get mini-batch train data; cache it so later iterations
                # do not have to repeat the join
                if len(index_data_inst_map) < batch_num:
                    batch_data_inst = data_instances.join(
                        batch_data_index, lambda data_inst, index: data_inst)
                    index_data_inst_map[batch_index] = batch_data_inst
                else:
                    batch_data_inst = index_data_inst_map[batch_index]

                # For inductive learning: transform node attributes to node embedding
                # self.transform(batch_data_inst)
                self.guest_forward = self.compute_forward(
                    batch_data_inst, self.embedding_, node2id, batch_index)

                host_forward = federation.get(
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info("Get host_forward from host")
                aggregate_forward_res = self.aggregate_forward(host_forward)
                en_aggregate_ee = aggregate_forward_res.mapValues(
                    lambda v: v[0])
                en_aggregate_ee_square = aggregate_forward_res.mapValues(
                    lambda v: v[1])

                # compute [[d]]
                if self.gradient_operator is None:
                    self.gradient_operator = HeteroNetworkEmbeddingGradient(
                        self.encrypt_operator)
                fore_gradient = self.gradient_operator.compute_fore_gradient(
                    batch_data_inst, en_aggregate_ee)

                host_gradient = self.gradient_operator.compute_gradient(
                    self.guest_forward.mapValues(
                        lambda v: Instance(features=v[1])), fore_gradient)
                federation.remote(
                    host_gradient,
                    name=self.transfer_variable.host_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote host_gradient to arbiter")

                composed_data_inst = host_forward.join(
                    batch_data_inst,
                    lambda hf, d: Instance(features=hf[1], label=d.label))
                guest_gradient, loss = self.gradient_operator.compute_gradient_and_loss(
                    composed_data_inst, fore_gradient, en_aggregate_ee,
                    en_aggregate_ee_square)
                federation.remote(
                    guest_gradient,
                    name=self.transfer_variable.guest_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote guest_gradient to arbiter")

                optim_guest_gradient = federation.get(
                    name=self.transfer_variable.guest_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_guest_gradient from arbiter")

                # update node embedding
                LOGGER.info("Update node embedding")
                nodeid_join_gradient = batch_data_inst.join(
                    optim_guest_gradient, lambda instance, gradient:
                    (node2id[instance.features], gradient))
                self.update_model(nodeid_join_gradient)

                # update local model that transform attribute to node embedding
                training_info = {
                    'iteration': self.n_iter_,
                    'batch_index': batch_index
                }
                self.update_local_model(fore_gradient, batch_data_inst,
                                        self.embedding_, **training_info)

                # NOTE: the loss needs to be encrypted

                federation.remote(
                    loss,
                    name=self.transfer_variable.loss.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.loss, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote loss to arbiter")

                # is converge of loss in arbiter
                batch_index += 1

                # remove temporary resource
                rubbish_list = [
                    host_forward, aggregate_forward_res, en_aggregate_ee,
                    en_aggregate_ee_square, fore_gradient, self.guest_forward
                ]
                rubbish_clear(rubbish_list)

            ##########
            guest_common_embedding = common_node_instances.mapValues(
                lambda node: self.embedding_[node2id[node]])
            federation.remote(
                guest_common_embedding,
                name=self.transfer_variable.guest_common_embedding.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.guest_common_embedding,
                    self.n_iter_, 1),
                role=consts.ARBITER,
                idx=0)

            common_embedding = federation.get(
                name=self.transfer_variable.common_embedding.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.common_embedding, self.n_iter_, 1),
                idx=0)

            self.update_common_nodes(common_embedding, common_nodes, node2id)
            ##########

            is_stopped = federation.get(
                name=self.transfer_variable.is_stopped.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.is_stopped, self.n_iter_),
                idx=0)

            LOGGER.info("Get is_stop flag from arbiter:{}".format(is_stopped))

            self.n_iter_ += 1
            if is_stopped:
                LOGGER.info(
                    "Get stop signal from arbiter, model is converged, iter:{}"
                    .format(self.n_iter_))
                break

        embedding_table = eggroll.table(name='guest',
                                        namespace='node_embedding',
                                        partition=10)
        id2node = dict(zip(node2id.values(), node2id.keys()))
        for id, embedding in enumerate(self.embedding_):
            embedding_table.put(id2node[id], embedding)
        embedding_table.save_as(name='guest',
                                namespace='node_embedding',
                                partition=10)
        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))
Example #9
    def fit(self, data_instances, validate_data=None):
        LOGGER.debug("Start data count: {}".format(data_instances.count()))

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)

        self.model_weights = self._init_model_variables(data_instances)

        LOGGER.debug("After init, model_weights: {}".format(
            self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)

        model_weights = self.model_weights
        degree = 0
        while self.n_iter_ < self.max_iter + 1:
            LOGGER.info("iter:{}".format(self.n_iter_))
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            if (self.n_iter_ > 0 and self.n_iter_ % self.aggregate_iters
                    == 0) or self.n_iter_ == self.max_iter:
                weight = self.aggregator.aggregate_then_get(
                    weight, degree=degree, suffix=self.n_iter_)

                weight._weights = np.array(weight._weights)
                # This weight is a transferable Weight; extract its parameters
                self.model_weights.update(weight)

                LOGGER.debug(
                    "Before aggregate: {}, degree: {} after aggregated: {}".
                    format(model_weights.unboxed / degree, degree,
                           self.model_weights.unboxed))

                loss = self._compute_loss(data_instances)
                self.aggregator.send_loss(loss,
                                          degree=degree,
                                          suffix=(self.n_iter_, ))
                degree = 0

                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info("n_iters: {}, is_converge: {}".format(
                    self.n_iter_, self.is_converged))
                if self.is_converged:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                LOGGER.debug("before compute_gradient,w_:{},embed_:{}".format(
                    model_weights.w_, model_weights.embed_))
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      w=model_weights.w_,
                                      embed=model_weights.embed_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.mapPartitions(f).reduce(
                    fate_operator.reduce_add)
                grad /= n
                weight = self.optimizer.update_model(model_weights,
                                                     grad,
                                                     has_applied=False)
                weight._weights = np.array(weight._weights)
                model_weights.update(weight)
                batch_num += 1
                degree += n

            validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1

        LOGGER.info("Finish Training task, total iters: {}".format(
            self.n_iter_))
Example #10
    def fit(self, data_instances, validate_data=None):

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)

        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        self.model_weights = self._init_model_variables(data_instances)

        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        model_weights = self.model_weights

        degree = 0
        while self.n_iter_ < self.max_iter + 1:
            LOGGER.info("iter:{}".format(self.n_iter_))
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            self.optimizer.set_iters(self.n_iter_)
            if (self.n_iter_ > 0 and self.n_iter_ % self.aggregate_iters
                    == 0) or self.n_iter_ == self.max_iter:
                # This branch runs only after weight has been created in the batch loop below (an LR weights object)
                weight = self.aggregator.aggregate_then_get(
                    weight, degree=degree, suffix=self.n_iter_)

                weight._weights = np.array(weight._weights)
                # This weight is a transferable Weight; extract its parameters
                self.model_weights.update(weight)

                LOGGER.debug(
                    "Before aggregate: {}, degree: {} after aggregated: {}".
                    format(model_weights.unboxed / degree, degree,
                           self.model_weights.unboxed))

                loss = self._compute_loss(data_instances)
                self.aggregator.send_loss(loss,
                                          degree=degree,
                                          suffix=(self.n_iter_, ))
                degree = 0

                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info(
                    "n_iters: {}, loss: {} converge flag is :{}".format(
                        self.n_iter_, loss, self.is_converged))
                if self.is_converged:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                LOGGER.debug(
                    "In each batch, fm_weight: {}, batch_data count: {},w:{},embed:{}"
                    .format(model_weights.unboxed, n, model_weights.w_,
                            model_weights.embed_))
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      w=model_weights.w_,
                                      embed=model_weights.embed_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.mapPartitions(f).reduce(
                    fate_operator.reduce_add)
                grad /= n
                LOGGER.debug(
                    'iter: {}, batch_index: {}, grad: {}, n: {}'.format(
                        self.n_iter_, batch_num, grad, n))

                weight = self.optimizer.update_model(model_weights,
                                                     grad,
                                                     has_applied=False)
                weight._weights = np.array(weight._weights)
                model_weights.update(weight)
                batch_num += 1
                degree += n

            validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1
Example #11
    def fit(self, data_instances):
        LOGGER.info("Enter hetero_lr_guest fit")
        data_instances = data_instances.mapValues(HeteroLRGuest.load_data)

        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)
        LOGGER.info("Get public_key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        LOGGER.info("Generate mini-batch from input data")
        mini_batch_obj = MiniBatch(data_instances, batch_size=self.batch_size)
        batch_info = {
            "batch_size": self.batch_size,
            "batch_num": mini_batch_obj.batch_nums
        }
        LOGGER.info("batch_info:" + str(batch_info))
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.HOST,
                          idx=0)
        LOGGER.info("Remote batch_info to Host")
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.ARBITER,
                          idx=0)
        LOGGER.info("Remote batch_info to Arbiter")

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(
            self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        weight = self.initializer.init_model(model_shape,
                                             init_params=self.init_param_obj)
        if self.init_param_obj.fit_intercept is True:
            self.coef_ = weight[:-1]
            self.intercept_ = weight[-1]
        else:
            self.coef_ = weight

        is_stopped = False
        is_send_all_batch_index = False
        self.n_iter_ = 0
        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))
            batch_data_generator = mini_batch_obj.mini_batch_index_generator(
                data_inst=data_instances, batch_size=self.batch_size)
            batch_index = 0
            for batch_data_index in batch_data_generator:
                LOGGER.info("batch:{}".format(batch_index))
                if not is_send_all_batch_index:
                    LOGGER.info("remote mini-batch index to Host")
                    federation.remote(
                        batch_data_index,
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        role=consts.HOST,
                        idx=0)
                    if batch_index >= mini_batch_obj.batch_nums - 1:
                        is_send_all_batch_index = True

                # Get mini-batch train data
                batch_data_inst = data_instances.join(
                    batch_data_index, lambda data_inst, index: data_inst)

                # guest/host forward
                self.compute_forward(batch_data_inst, self.coef_,
                                     self.intercept_)
                host_forward = federation.get(
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info("Get host_forward from host")
                aggregate_forward_res = self.aggregate_forward(host_forward)
                en_aggregate_wx = aggregate_forward_res.mapValues(
                    lambda v: v[0])
                en_aggregate_wx_square = aggregate_forward_res.mapValues(
                    lambda v: v[1])

                # compute [[d]]
                if self.gradient_operator is None:
                    self.gradient_operator = HeteroLogisticGradient(
                        self.encrypt_operator)
                fore_gradient = self.gradient_operator.compute_fore_gradient(
                    batch_data_inst, en_aggregate_wx)
                federation.remote(
                    fore_gradient,
                    name=self.transfer_variable.fore_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.fore_gradient, self.n_iter_,
                        batch_index),
                    role=consts.HOST,
                    idx=0)
                LOGGER.info("Remote fore_gradient to Host")
                # compute guest gradient and loss
                guest_gradient, loss = self.gradient_operator.compute_gradient_and_loss(
                    batch_data_inst, fore_gradient, en_aggregate_wx,
                    en_aggregate_wx_square, self.fit_intercept)

                # loss regulation if necessary
                if self.updater is not None:
                    guest_loss_regular = self.updater.loss_norm(self.coef_)
                    loss += self.encrypt_operator.encrypt(guest_loss_regular)

                federation.remote(
                    guest_gradient,
                    name=self.transfer_variable.guest_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote guest_gradient to arbiter")

                optim_guest_gradient = federation.get(
                    name=self.transfer_variable.guest_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_guest_gradient from arbiter")

                # update model
                LOGGER.info("update_model")
                self.update_model(optim_guest_gradient)

                # Get loss regulation from Host if regulation is set
                if self.updater is not None:
                    en_host_loss_regular = federation.get(
                        name=self.transfer_variable.host_loss_regular.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_loss_regular,
                            self.n_iter_, batch_index),
                        idx=0)
                    LOGGER.info("Get host_loss_regular from Host")
                    loss += en_host_loss_regular

                federation.remote(
                    loss,
                    name=self.transfer_variable.loss.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.loss, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote loss to arbiter")

                # is converge of loss in arbiter
                is_stopped = federation.get(
                    name=self.transfer_variable.is_stopped.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.is_stopped, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info(
                    "Get is_stop flag from arbiter:{}".format(is_stopped))
                batch_index += 1
                if is_stopped:
                    LOGGER.info(
                        "Get stop signal from arbiter, model is converged, iter:{}"
                        .format(self.n_iter_))
                    break

            self.n_iter_ += 1
            if is_stopped:
                break
        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))
Example #12
    def fit_binary(self, data_instances, validate_data=None):
        self.aggregator = aggregator.Host()
        self.aggregator.register_aggregator(self.transfer_variable)

        self._client_check_data(data_instances)
        self.callback_list.on_train_begin(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        if not self.component_properties.is_warm_start:
            self.model_weights = self._init_model_variables(data_instances)
            if self.use_encrypt:
                w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
            else:
                w = list(self.model_weights.unboxed)
            self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)
        else:
            self.callback_warm_start_init_iter(self.n_iter_)

        # LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
            re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1
            # LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
            #     re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        # total_data_num = data_instances.count()
        # LOGGER.debug("Current data count: {}".format(total_data_num))

        model_weights = self.model_weights
        self.prev_round_weights = copy.deepcopy(model_weights)
        degree = 0
        while self.n_iter_ < self.max_iter + 1:
            self.callback_list.on_epoch_begin(self.n_iter_)

            batch_data_generator = mini_batch_obj.mini_batch_data_generator()
            self.optimizer.set_iters(self.n_iter_)

            if ((self.n_iter_ + 1) % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter:
                weight = self.aggregator.aggregate_then_get(model_weights, degree=degree,
                                                            suffix=self.n_iter_)
                # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format(
                #     model_weights.unboxed / degree,
                #     degree,
                #     weight.unboxed))
                self.model_weights = LogisticRegressionWeights(weight.unboxed, self.fit_intercept)
                if not self.use_encrypt:
                    loss = self._compute_loss(data_instances, self.prev_round_weights)
                    self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_,))
                    LOGGER.info("n_iters: {}, loss: {}".format(self.n_iter_, loss))
                degree = 0
                self.is_converged = self.aggregator.get_converge_status(suffix=(self.n_iter_,))
                LOGGER.info("n_iters: {}, is_converge: {}".format(self.n_iter_, self.is_converged))
                if self.is_converged or self.n_iter_ == self.max_iter:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                degree += n
                LOGGER.debug('before compute_gradient')
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.applyPartitions(f).reduce(fate_operator.reduce_add)
                grad /= n
                if self.use_proximal:  # use additional proximal term
                    model_weights = self.optimizer.update_model(model_weights,
                                                                grad=grad,
                                                                has_applied=False,
                                                                prev_round_weights=self.prev_round_weights)
                else:
                    model_weights = self.optimizer.update_model(model_weights,
                                                                grad=grad,
                                                                has_applied=False)

                if self.use_encrypt and batch_num % self.re_encrypt_batches == 0:
                    LOGGER.debug("Before accept re_encrypted_model, batch_iter_num: {}".format(batch_num))
                    w = self.cipher.re_cipher(w=model_weights.unboxed,
                                              iter_num=self.n_iter_,
                                              batch_iter_num=batch_num)
                    model_weights = LogisticRegressionWeights(w, self.fit_intercept)
                batch_num += 1

            # validation_strategy.validate(self, self.n_iter_)
            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1
            if self.stop_training:
                break

        self.set_summary(self.get_model_summary())
        LOGGER.info("Finish Training task, total iters: {}".format(self.n_iter_))
Example #13
    def fit(self, data_instances):
        """
        Train lr model of role guest
        Parameters
        ----------
        data_instances: DTable of Instance, input data
        """

        LOGGER.info("Enter hetero_lr_guest fit")
        self._abnormal_detection(data_instances)

        self.header = self.get_header(data_instances)
        data_instances = data_instances.mapValues(HeteroLRGuest.load_data)

        # Get the public key from the arbiter
        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)
        LOGGER.info("Get public_key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        LOGGER.info("Generate mini-batch from input data")
        mini_batch_obj = MiniBatch(data_instances, batch_size=self.batch_size)
        batch_num = mini_batch_obj.batch_nums
        if self.batch_size == -1:
            LOGGER.info(
                "batch size is -1, set it to the number of data in data_instances"
            )
            self.batch_size = data_instances.count()

        batch_info = {"batch_size": self.batch_size, "batch_num": batch_num}
        LOGGER.info("batch_info:{}".format(batch_info))
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.HOST,
                          idx=0)
        LOGGER.info("Remote batch_info to Host")
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.ARBITER,
                          idx=0)
        LOGGER.info("Remote batch_info to Arbiter")

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.encrypt_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(batch_num)
        ]

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(
            self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        weight = self.initializer.init_model(model_shape,
                                             init_params=self.init_param_obj)
        if self.init_param_obj.fit_intercept is True:
            self.coef_ = weight[:-1]
            self.intercept_ = weight[-1]
        else:
            self.coef_ = weight

        is_send_all_batch_index = False
        self.n_iter_ = 0
        index_data_inst_map = {}

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))
            # each iter will get the same batch_data_generator
            batch_data_generator = mini_batch_obj.mini_batch_data_generator(
                result='index')

            batch_index = 0
            for batch_data_index in batch_data_generator:
                LOGGER.info("batch:{}".format(batch_index))
                if not is_send_all_batch_index:
                    LOGGER.info("remote mini-batch index to Host")
                    federation.remote(
                        batch_data_index,
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        role=consts.HOST,
                        idx=0)
                    if batch_index >= mini_batch_obj.batch_nums - 1:
                        is_send_all_batch_index = True

                # Get mini-batch train data
                if len(index_data_inst_map) < batch_num:
                    batch_data_inst = data_instances.join(
                        batch_data_index, lambda data_inst, index: data_inst)
                    index_data_inst_map[batch_index] = batch_data_inst
                else:
                    batch_data_inst = index_data_inst_map[batch_index]

                # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
                batch_feat_inst = self.transform(batch_data_inst)

                # guest/host forward
                self.compute_forward(batch_feat_inst, self.coef_,
                                     self.intercept_, batch_index)
                host_forward = federation.get(
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info("Get host_forward from host")
                aggregate_forward_res = self.aggregate_forward(host_forward)
                en_aggregate_wx = aggregate_forward_res.mapValues(
                    lambda v: v[0])
                en_aggregate_wx_square = aggregate_forward_res.mapValues(
                    lambda v: v[1])

                # compute [[d]]
                if self.gradient_operator is None:
                    self.gradient_operator = HeteroLogisticGradient(
                        self.encrypt_operator)
                fore_gradient = self.gradient_operator.compute_fore_gradient(
                    batch_feat_inst, en_aggregate_wx)
                federation.remote(
                    fore_gradient,
                    name=self.transfer_variable.fore_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.fore_gradient, self.n_iter_,
                        batch_index),
                    role=consts.HOST,
                    idx=0)

                LOGGER.info("Remote fore_gradient to Host")
                # compute guest gradient and loss
                guest_gradient, loss = self.gradient_operator.compute_gradient_and_loss(
                    batch_feat_inst, fore_gradient, en_aggregate_wx,
                    en_aggregate_wx_square, self.fit_intercept)

                # add loss regularization if configured
                if self.updater is not None:
                    guest_loss_regular = self.updater.loss_norm(self.coef_)
                    loss += self.encrypt_operator.encrypt(guest_loss_regular)

                federation.remote(
                    guest_gradient,
                    name=self.transfer_variable.guest_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote guest_gradient to arbiter")

                optim_guest_gradient = federation.get(
                    name=self.transfer_variable.guest_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_guest_gradient from arbiter")

                # update model
                LOGGER.info("update_model")
                self.update_model(optim_guest_gradient)

                # update local model that transforms features of raw input 'batch_data_inst'
                training_info = {
                    "iteration": self.n_iter_,
                    "batch_index": batch_index
                }
                self.update_local_model(fore_gradient, batch_data_inst,
                                        self.coef_, **training_info)

                # Get encrypted loss regularization term from Host if regularization is set
                if self.updater is not None:
                    en_host_loss_regular = federation.get(
                        name=self.transfer_variable.host_loss_regular.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_loss_regular,
                            self.n_iter_, batch_index),
                        idx=0)
                    LOGGER.info("Get host_loss_regular from Host")
                    loss += en_host_loss_regular

                federation.remote(
                    loss,
                    name=self.transfer_variable.loss.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.loss, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote loss to arbiter")

                # convergence of the loss is checked on the arbiter side
                batch_index += 1

                # temporary resource cleanup; will be removed in the future
                rubbish_list = [
                    host_forward, aggregate_forward_res, en_aggregate_wx,
                    en_aggregate_wx_square, fore_gradient, self.guest_forward
                ]
                rubbish_clear(rubbish_list)

            is_stopped = federation.get(
                name=self.transfer_variable.is_stopped.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.is_stopped, self.n_iter_,
                    batch_index),
                idx=0)
            LOGGER.info("Get is_stop flag from arbiter:{}".format(is_stopped))

            self.n_iter_ += 1
            if is_stopped:
                LOGGER.info(
                    "Get stop signal from arbiter, model is converged, iter:{}"
                    .format(self.n_iter_))
                break

        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))
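The guest loop above keys every federation.remote/federation.get call with a tag built from the transfer variable plus the current iteration and batch index, so messages from different batches never collide. The snippet below is a minimal, hypothetical sketch of that tagging idea; make_transfer_tag is an illustrative helper, not FATE's generate_transferid.

# Minimal sketch of per-iteration/per-batch message tagging.
# `make_transfer_tag` is an illustrative helper, not FATE's generate_transferid.
def make_transfer_tag(variable_name, *suffixes):
    """Join a transfer-variable name with iteration/batch suffixes into a unique tag."""
    return ".".join([variable_name] + [str(s) for s in suffixes])

# e.g. the fore_gradient sent at iteration 3, batch 2 gets its own tag:
print(make_transfer_tag("fore_gradient", 3, 2))  # fore_gradient.3.2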
Example No. 14
    def fit(self, data_instances, validate_data=None):
        LOGGER.debug("Start data count: {}".format(data_instances.count()))

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        self.model_weights = self._init_model_variables(data_instances)
        w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
        self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)

        LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
            re_encrypt_times = total_batch_num // self.re_encrypt_batches + 1
            LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
                re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        total_data_num = data_instances.count()
        LOGGER.debug("Current data count: {}".format(total_data_num))

        model_weights = self.model_weights
        degree = 0

        self.__synchronize_encryption()
        self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get(idx=0, suffix=('train',))
        LOGGER.debug("party num:" + str(self.zcl_num_party))
        self.__init_model()

        self.train_loss_results = []
        self.train_accuracy_results = []
        self.test_loss_results = []
        self.test_accuracy_results = []

        for iter_num in range(self.max_iter):
            # mini-batch
            LOGGER.debug("In iter: {}".format(iter_num))
            # batch_data_generator = self.mini_batch_obj.mini_batch_data_generator()
            batch_num = 0
            total_loss = 0
            epoch_train_loss_avg = tfe.metrics.Mean()
            epoch_train_accuracy = tfe.metrics.Accuracy()

            for train_x, train_y in self.zcl_dataset:
                LOGGER.info("Staring batch {}".format(batch_num))
                start_t = time.time()
                loss_value, grads = self.__grad(self.zcl_model, train_x, train_y)
                loss_value = loss_value.numpy()
                grads = [x.numpy() for x in grads]
                LOGGER.info("Start encrypting")
                loss_value = batch_encryption.encrypt(self.zcl_encrypt_operator.get_public_key(), loss_value)
                grads = [batch_encryption.encrypt_matrix(self.zcl_encrypt_operator.get_public_key(), x) for x in grads]
                LOGGER.info("Finish encrypting")
                grads = Gradients(grads)
                self.transfer_variable.host_grad.remote(obj=grads.for_remote(), role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Sent grads")
                self.transfer_variable.host_loss.remote(obj=loss_value, role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Sent loss")

                sum_grads = self.transfer_variable.aggregated_grad.get(idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got grads")
                sum_loss = self.transfer_variable.aggregated_loss.get(idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got loss")

                sum_loss = batch_encryption.decrypt(self.zcl_encrypt_operator.get_privacy_key(), sum_loss)
                sum_grads = [
                    batch_encryption.decrypt_matrix(self.zcl_encrypt_operator.get_privacy_key(), x).astype(np.float32)
                    for x
                    in sum_grads.unboxed]
                LOGGER.info("Finish decrypting")

                # sum_grads = np.array(sum_grads) / self.zcl_num_party

                self.zcl_optimizer.apply_gradients(zip(sum_grads, self.zcl_model.trainable_variables),
                                                   self.zcl_global_step)

                elapsed_time = time.time() - start_t
                # epoch_train_loss_avg(loss_value)
                # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32),
                #                      train_y)
                self.train_loss_results.append(sum_loss)
                train_accuracy_v = accuracy_score(train_y,
                                                  tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32))
                self.train_accuracy_results.append(train_accuracy_v)
                test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test, self.zcl_y_test)
                self.test_loss_results.append(test_loss_v)
                test_accuracy_v = accuracy_score(self.zcl_y_test,
                                                 tf.argmax(self.zcl_model(self.zcl_x_test), axis=1,
                                                           output_type=tf.int32))
                self.test_accuracy_results.append(test_accuracy_v)

                LOGGER.info(
                    "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, "
                    "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format(
                        iter_num,
                        batch_num,
                        sum_loss,
                        train_accuracy_v,
                        test_loss_v,
                        test_accuracy_v,
                        elapsed_time)
                )

                batch_num += 1

                if batch_num >= self.zcl_early_stop_batch:
                    return

            self.n_iter_ = iter_num
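Each batch above follows an encrypt, remote, aggregate, decrypt round trip that relies on an additively homomorphic cipher. The standalone sketch below reproduces that idea with the python-paillier (phe) package rather than FATE's batch_encryption helpers; the party losses are made-up values.

# Standalone sketch of additively homomorphic aggregation with python-paillier.
from phe import paillier

public_key, private_key = paillier.generate_paillier_keypair(n_length=1024)

party_losses = [0.42, 0.37, 0.51]                       # plaintext per-party losses
encrypted = [public_key.encrypt(v) for v in party_losses]

# The aggregator sums ciphertexts without seeing any plaintext value.
encrypted_sum = encrypted[0] + encrypted[1] + encrypted[2]

print(round(private_key.decrypt(encrypted_sum), 6))     # 1.3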
Example No. 15
    def fit(self, data_instances):
        LOGGER.info("parameters: alpha: {}, eps: {}, max_iter: {}"
                    "batch_size: {}".format(self.alpha, self.eps,
                                            self.max_iter, self.batch_size))
        self.__init_parameters()

        w = self.__init_model(data_instances)

        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        for iter_num in range(self.max_iter):
            # mini-batch
            # LOGGER.debug("Enter iter_num: {}".format(iter_num))
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()
            total_loss = 0
            batch_num = 0
            for batch_data in batch_data_generator:
                f = functools.partial(self.gradient_operator.compute,
                                      coef=self.coef_,
                                      intercept=self.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad_loss = batch_data.mapPartitions(f)
                n = grad_loss.count()
                grad, loss = grad_loss.reduce(
                    self.aggregator.aggregate_grad_loss)
                grad /= n
                loss /= n

                if self.updater is not None:
                    loss_norm = self.updater.loss_norm(self.coef_)
                    total_loss += (loss + loss_norm)
                # LOGGER.debug("before update: {}".format(grad))
                delta_grad = self.optimizer.apply_gradients(grad)
                # LOGGER.debug("after apply: {}".format(delta_grad))

                self.update_model(delta_grad)
                batch_num += 1

            total_loss /= batch_num
            w = self.merge_model()
            LOGGER.info("iter: {}, loss: {}".format(iter_num, total_loss))
            # send model
            model_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.guest_model, iter_num)
            federation.remote(w,
                              name=self.transfer_variable.guest_model.name,
                              tag=model_transfer_id,
                              role=consts.ARBITER,
                              idx=0)
            # send loss
            loss_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.guest_loss, iter_num)
            federation.remote(total_loss,
                              name=self.transfer_variable.guest_loss.name,
                              tag=loss_transfer_id,
                              role=consts.ARBITER,
                              idx=0)

            # recv model
            model_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.final_model, iter_num)

            w = federation.get(name=self.transfer_variable.final_model.name,
                               tag=model_transfer_id,
                               idx=0)
            w = np.array(w)
            # LOGGER.debug("Received final model: {}".format(w))
            self.set_coef_(w)

            # recv converge flag
            converge_flag_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.converge_flag, iter_num)
            converge_flag = federation.get(
                name=self.transfer_variable.converge_flag.name,
                tag=converge_flag_id,
                idx=0)
            self.n_iter_ = iter_num
            LOGGER.debug("converge flag is :{}".format(converge_flag))

            if converge_flag:
                # self.save_model(w)
                break
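The per-batch gradient above comes from summing partial results over table partitions (mapPartitions) and dividing by the batch count. A plain-numpy sketch of that pattern, using a toy logistic-regression gradient and invented data in place of a DTable, could look like this:

# Plain-numpy sketch of partition-wise gradient/loss sums followed by averaging.
import numpy as np

def partition_grad_sum(rows, coef, intercept):
    """Return (gradient_sum, loss_sum, count) for one partition of (x, y) pairs."""
    g, loss, n = np.zeros_like(coef), 0.0, 0
    for x, y in rows:
        pred = 1.0 / (1.0 + np.exp(-(x @ coef + intercept)))
        g += (pred - y) * x
        loss += -(y * np.log(pred) + (1 - y) * np.log(1 - pred))
        n += 1
    return g, loss, n

coef, intercept = np.zeros(2), 0.0
partitions = [
    [(np.array([1.0, 0.0]), 1), (np.array([0.0, 1.0]), 0)],
    [(np.array([1.0, 1.0]), 1)],
]
parts = [partition_grad_sum(p, coef, intercept) for p in partitions]
total_n = sum(n for _, _, n in parts)
grad = sum(g for g, _, _ in parts) / total_n
loss = sum(l for _, l, _ in parts) / total_n
print(grad, loss)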
Example No. 16
    def fit(self, data_instances):
        self._abnormal_detection(data_instances)

        self.header = data_instances.schema.get(
            'header')  # ['x1', 'x2', 'x3' ... ]

        self.__init_parameters()

        self.__init_model(data_instances)

        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)

        for iter_num in range(self.max_iter):
            # mini-batch
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()
            total_loss = 0
            batch_num = 0

            for batch_data in batch_data_generator:
                n = batch_data.count()

                f = functools.partial(self.gradient_operator.compute,
                                      coef=self.coef_,
                                      intercept=self.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad_loss = batch_data.mapPartitions(f)

                grad, loss = grad_loss.reduce(
                    self.aggregator.aggregate_grad_loss)

                grad /= n
                loss /= n

                if self.updater is not None:
                    loss_norm = self.updater.loss_norm(self.coef_)
                    total_loss += (loss + loss_norm)
                delta_grad = self.optimizer.apply_gradients(grad)

                self.update_model(delta_grad)
                batch_num += 1

            total_loss /= batch_num
            w = self.merge_model()
            self.loss_history.append(total_loss)
            LOGGER.info("iter: {}, loss: {}".format(iter_num, total_loss))
            # send model
            model_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.guest_model, iter_num)
            federation.remote(w,
                              name=self.transfer_variable.guest_model.name,
                              tag=model_transfer_id,
                              role=consts.ARBITER,
                              idx=0)

            # send loss

            loss_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.guest_loss, iter_num)
            federation.remote(total_loss,
                              name=self.transfer_variable.guest_loss.name,
                              tag=loss_transfer_id,
                              role=consts.ARBITER,
                              idx=0)

            # recv model
            model_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.final_model, iter_num)
            w = federation.get(name=self.transfer_variable.final_model.name,
                               tag=model_transfer_id,
                               idx=0)

            w = np.array(w)
            self.set_coef_(w)

            # recv converge flag
            converge_flag_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.converge_flag, iter_num)
            converge_flag = federation.get(
                name=self.transfer_variable.converge_flag.name,
                tag=converge_flag_id,
                idx=0)

            self.n_iter_ = iter_num
            LOGGER.debug("converge flag is :{}".format(converge_flag))

            if converge_flag:
                self.is_converged = True
                break

        self.show_meta()
        self.show_model()
        LOGGER.debug("in fit self coef: {}".format(self.coef_))
        return data_instances
Example No. 17
    def fit(self, data_instances, validate_data=None):

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)

        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        self.model_weights = self._init_model_variables(data_instances)

        max_iter = self.max_iter
        total_data_num = data_instances.count()
        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        model_weights = self.model_weights

        self.__synchronize_encryption()
        self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get(
            idx=0, suffix=('train', ))
        LOGGER.debug("party num:" + str(self.zcl_num_party))
        self.__init_model()

        self.train_loss_results = []
        self.train_accuracy_results = []
        self.test_loss_results = []
        self.test_accuracy_results = []

        batch_num = 0

        for iter_num in range(self.max_iter):
            total_loss = 0
            # batch_num = 0
            iter_num_ = 0
            epoch_train_loss_avg = tfe.metrics.Mean()
            epoch_train_accuracy = tfe.metrics.Accuracy()

            for train_x, train_y in self.zcl_dataset:
                LOGGER.info("Staring batch {}".format(batch_num))
                start_t = time.time()
                loss_value, grads = self.__grad(self.zcl_model, train_x,
                                                train_y)
                loss_value = loss_value.numpy()
                grads = [x.numpy() for x in grads]

                sizes = [layer.size * self.zcl_num_party for layer in grads]
                guest_max = [np.max(layer) for layer in grads]
                guest_min = [np.min(layer) for layer in grads]

                # clipping_thresholds_guest = batch_encryption.calculate_clip_threshold_aciq_l(grads, bit_width=self.bit_width)
                grad_max_all = self.transfer_variable.host_grad_max.get(
                    idx=-1, suffix=(iter_num_, batch_num))
                grad_min_all = self.transfer_variable.host_grad_min.get(
                    idx=-1, suffix=(iter_num_, batch_num))
                grad_max_all.append(guest_max)
                grad_min_all.append(guest_min)
                max_v = []
                min_v = []
                for layer_idx in range(len(grads)):
                    max_v.append(
                        [np.max([party[layer_idx] for party in grad_max_all])])
                    min_v.append(
                        [np.min([party[layer_idx] for party in grad_min_all])])
                grads_max_min = np.concatenate(
                    [np.array(max_v), np.array(min_v)], axis=1)
                clipping_thresholds = batch_encryption.calculate_clip_threshold_aciq_g(
                    grads_max_min, sizes, bit_width=self.bit_width)
                LOGGER.info("clipping threshold " + str(clipping_thresholds))

                r_maxs = [x * self.zcl_num_party for x in clipping_thresholds]
                self.transfer_variable.clipping_threshold.remote(
                    obj=clipping_thresholds,
                    role=consts.HOST,
                    idx=-1,
                    suffix=(iter_num_, batch_num))
                grads = batch_encryption.clip_with_threshold(
                    grads, clipping_thresholds)

                LOGGER.info("Start batch encrypting")
                loss_value = batch_encryption.encrypt(
                    self.zcl_encrypt_operator.get_public_key(), loss_value)
                # grads = [batch_encryption.encrypt_matrix(self.zcl_encrypt_operator.get_public_key(), x) for x in grads]
                enc_grads, og_shape = batch_encryption.batch_enc_per_layer(
                    publickey=self.zcl_encrypt_operator.get_public_key(),
                    party=grads,
                    r_maxs=r_maxs,
                    bit_width=self.bit_width,
                    batch_size=self.e_batch_size)
                # grads = Gradients(enc_grads)
                LOGGER.info("Finish encrypting")
                # grads = self.encrypt_operator.get_public_key()
                # self.transfer_variable.guest_grad.remote(obj=grads.for_remote(), role=consts.ARBITER, idx=0,
                #                                          suffix=(iter_num_, batch_num))
                self.transfer_variable.guest_grad.remote(obj=enc_grads,
                                                         role=consts.ARBITER,
                                                         idx=0,
                                                         suffix=(iter_num_,
                                                                 batch_num))
                LOGGER.info("Sent grads")
                self.transfer_variable.guest_loss.remote(obj=loss_value,
                                                         role=consts.ARBITER,
                                                         idx=0,
                                                         suffix=(iter_num_,
                                                                 batch_num))
                LOGGER.info("Sent loss")

                sum_grads = self.transfer_variable.aggregated_grad.get(
                    idx=0, suffix=(iter_num_, batch_num))
                LOGGER.info("Got grads")
                sum_loss = self.transfer_variable.aggregated_loss.get(
                    idx=0, suffix=(iter_num_, batch_num))
                LOGGER.info("Got loss")

                sum_loss = batch_encryption.decrypt(
                    self.zcl_encrypt_operator.get_privacy_key(), sum_loss)
                # sum_grads = [
                #     batch_encryption.decrypt_matrix(self.zcl_encrypt_operator.get_privacy_key(), x).astype(np.float32) for x
                #     in sum_grads.unboxed]
                sum_grads = batch_encryption.batch_dec_per_layer(
                    privatekey=self.zcl_encrypt_operator.get_privacy_key(),
                    # party=sum_grads.unboxed, og_shapes=og_shape,
                    party=sum_grads,
                    og_shapes=og_shape,
                    r_maxs=r_maxs,
                    bit_width=self.bit_width,
                    batch_size=self.e_batch_size)
                LOGGER.info("Finish decrypting")

                # sum_grads = np.array(sum_grads) / self.zcl_num_party

                self.zcl_optimizer.apply_gradients(
                    zip(sum_grads, self.zcl_model.trainable_variables),
                    self.zcl_global_step)

                elapsed_time = time.time() - start_t
                # epoch_train_loss_avg(loss_value)
                # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32),
                #                      train_y)
                self.train_loss_results.append(sum_loss)
                train_accuracy_v = accuracy_score(
                    train_y,
                    tf.argmax(self.zcl_model(train_x),
                              axis=1,
                              output_type=tf.int32))
                self.train_accuracy_results.append(train_accuracy_v)
                test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test,
                                          self.zcl_y_test)
                self.test_loss_results.append(test_loss_v)
                test_accuracy_v = accuracy_score(
                    self.zcl_y_test,
                    tf.argmax(self.zcl_model(self.zcl_x_test),
                              axis=1,
                              output_type=tf.int32))
                self.test_accuracy_results.append(test_accuracy_v)

                LOGGER.info(
                    "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, "
                    "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format(
                        iter_num, batch_num, sum_loss, train_accuracy_v,
                        test_loss_v, test_accuracy_v, elapsed_time))

                batch_num += 1

                # if batch_num >= self.zcl_early_stop_batch:
                #     return

            self.n_iter_ = iter_num
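Before batched encryption, the gradients above are clipped to per-layer thresholds derived from the global min/max across parties. A minimal, hypothetical sketch of just the clipping step (not FATE's batch_encryption.clip_with_threshold), with invented values:

# Minimal sketch of clipping each gradient layer into [-t, t] for its threshold t.
import numpy as np

def clip_with_threshold(grads, thresholds):
    """Clip every layer's values to its per-layer threshold."""
    return [np.clip(g, -t, t) for g, t in zip(grads, thresholds)]

grads = [np.array([0.2, -3.0, 1.5]), np.array([[4.0, -0.1], [0.3, -5.0]])]
thresholds = [1.0, 2.0]
print(clip_with_threshold(grads, thresholds))
# [array([ 0.2, -1. ,  1. ]), array([[ 2. , -0.1], [ 0.3, -2. ]])]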
Example No. 18
class Guest(batch_info_sync.Guest):
    def __init__(self):
        self.mini_batch_obj = None
        self.finish_sycn = False
        self.batch_nums = None
        self.batch_masked = False

    def register_batch_generator(self, transfer_variables, has_arbiter=True):
        self._register_batch_data_index_transfer(transfer_variables.batch_info,
                                                 transfer_variables.batch_data_index,
                                                 getattr(transfer_variables, "batch_validate_info", None),
                                                 has_arbiter)

    def initialize_batch_generator(self, data_instances, batch_size, suffix=tuple(),
                                   shuffle=False, batch_strategy="full", masked_rate=0):
        self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size, shuffle=shuffle,
                                        batch_strategy=batch_strategy, masked_rate=masked_rate)
        self.batch_nums = self.mini_batch_obj.batch_nums
        self.batch_masked = self.mini_batch_obj.batch_size != self.mini_batch_obj.masked_batch_size
        batch_info = {"batch_size": self.mini_batch_obj.batch_size, "batch_num": self.batch_nums,
                      "batch_mutable": self.mini_batch_obj.batch_mutable,
                      "masked_batch_size": self.mini_batch_obj.masked_batch_size}
        self.sync_batch_info(batch_info, suffix)

        if not self.mini_batch_obj.batch_mutable:
            self.prepare_batch_data(suffix)

    def prepare_batch_data(self, suffix=tuple()):
        self.mini_batch_obj.generate_batch_data()
        index_generator = self.mini_batch_obj.mini_batch_data_generator(result='index')
        batch_index = 0
        for batch_data_index in index_generator:
            batch_suffix = suffix + (batch_index,)
            self.sync_batch_index(batch_data_index, batch_suffix)
            batch_index += 1

    def generate_batch_data(self, with_index=False, suffix=tuple()):
        if self.mini_batch_obj.batch_mutable:
            self.prepare_batch_data(suffix)

        if with_index:
            data_generator = self.mini_batch_obj.mini_batch_data_generator(result='both')
            for batch_data, index_data in data_generator:
                yield batch_data, index_data
        else:
            data_generator = self.mini_batch_obj.mini_batch_data_generator(result='data')
            for batch_data in data_generator:
                yield batch_data

    def verify_batch_legality(self, suffix=tuple()):
        validate_infos = self.sync_batch_validate_info(suffix)
        least_batch_size = 0
        is_legal = True
        for validate_info in validate_infos:
            legality = validate_info.get("legality")
            if not legality:
                is_legal = False
                least_batch_size = max(least_batch_size, validate_info.get("least_batch_size"))

        if not is_legal:
            raise ValueError(f"To use batch masked strategy, "
                             f"(masked_rate + 1) * batch_size should > {least_batch_size}")
Example No. 19
    def fit(self, data_instances, validate_data=None):

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)

        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        self.model_weights = self._init_model_variables(data_instances)

        max_iter = self.max_iter
        # total_data_num = data_instances.count()
        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        model_weights = self.model_weights

        degree = 0
        while self.n_iter_ < max_iter:
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            self.optimizer.set_iters(self.n_iter_)
            if self.n_iter_ > 0 and self.n_iter_ % self.aggregate_iters == 0:
                weight = self.aggregator.aggregate_then_get(
                    model_weights, degree=degree, suffix=self.n_iter_)
                LOGGER.debug(
                    "Before aggregate: {}, degree: {} after aggregated: {}".
                    format(model_weights.unboxed / degree, degree,
                           weight.unboxed))

                self.model_weights = LogisticRegressionWeights(
                    weight.unboxed, self.fit_intercept)
                loss = self._compute_loss(data_instances)
                self.aggregator.send_loss(loss,
                                          degree=degree,
                                          suffix=(self.n_iter_, ))
                degree = 0

                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info(
                    "n_iters: {}, loss: {} converge flag is :{}".format(
                        self.n_iter_, loss, self.is_converged))
                if self.is_converged:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                LOGGER.debug(
                    "In each batch, lr_weight: {}, batch_data count: {}".
                    format(model_weights.unboxed, n))
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.mapPartitions(f).reduce(
                    fate_operator.reduce_add)
                grad /= n
                LOGGER.debug(
                    'iter: {}, batch_index: {}, grad: {}, n: {}'.format(
                        self.n_iter_, batch_num, grad, n))
                model_weights = self.optimizer.update_model(model_weights,
                                                            grad,
                                                            has_applied=False)
                batch_num += 1
                degree += n

            validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1
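Every aggregate_iters iterations the guest hands its weights and a degree (the number of samples consumed since the last aggregation) to the aggregator. One common form of such degree-weighted averaging, sketched here with made-up numbers and independently of FATE's aggregator API:

# Hypothetical sketch of sample-count (degree) weighted model averaging.
import numpy as np

party_weights = [np.array([0.2, 0.4]), np.array([0.6, 0.0])]
degrees = [300, 100]   # samples each party processed since the last aggregation

aggregated = sum(w * d for w, d in zip(party_weights, degrees)) / sum(degrees)
print(aggregated)      # [0.3 0.3]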
Example No. 20
    def fit(self, data_instances):
        if not self.need_run:
            return data_instances

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        self.__init_parameters()

        self.__init_model(data_instances)

        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)

        for iter_num in range(self.max_iter):
            # mini-batch
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()
            total_loss = 0
            batch_num = 0

            for batch_data in batch_data_generator:
                n = batch_data.count()

                f = functools.partial(self.gradient_operator.compute,
                                      coef=self.coef_,
                                      intercept=self.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad_loss = batch_data.mapPartitions(f)

                grad, loss = grad_loss.reduce(
                    self.aggregator.aggregate_grad_loss)

                grad /= n
                loss /= n

                if self.updater is not None:
                    loss_norm = self.updater.loss_norm(self.coef_)
                    total_loss += (loss + loss_norm)
                delta_grad = self.optimizer.apply_gradients(grad)

                self.update_model(delta_grad)
                batch_num += 1

            total_loss /= batch_num

            # if not self.use_loss:
            #     total_loss = np.linalg.norm(self.coef_)

            w = self.merge_model()
            if not self.need_one_vs_rest:
                metric_meta = MetricMeta(name='train',
                                         metric_type="LOSS",
                                         extra_metas={
                                             "unit_name": "iters",
                                         })
                # metric_name = self.get_metric_name('loss')

                self.callback_meta(metric_name='loss',
                                   metric_namespace='train',
                                   metric_meta=metric_meta)
                self.callback_metric(
                    metric_name='loss',
                    metric_namespace='train',
                    metric_data=[Metric(iter_num, total_loss)])

            self.loss_history.append(total_loss)
            LOGGER.info("iter: {}, loss: {}".format(iter_num, total_loss))
            # send model
            model_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.guest_model, iter_num)
            LOGGER.debug("Start to remote model: {}, transfer_id: {}".format(
                w, model_transfer_id))

            federation.remote(w,
                              name=self.transfer_variable.guest_model.name,
                              tag=model_transfer_id,
                              role=consts.ARBITER,
                              idx=0)

            # send loss
            # if self.use_loss:
            loss_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.guest_loss, iter_num)
            LOGGER.debug(
                "Start to remote total_loss: {}, transfer_id: {}".format(
                    total_loss, loss_transfer_id))
            federation.remote(total_loss,
                              name=self.transfer_variable.guest_loss.name,
                              tag=loss_transfer_id,
                              role=consts.ARBITER,
                              idx=0)

            # recv model
            model_transfer_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.final_model, iter_num)
            w = federation.get(name=self.transfer_variable.final_model.name,
                               tag=model_transfer_id,
                               idx=0)

            w = np.array(w)
            self.set_coef_(w)

            # recv converge flag
            converge_flag_id = self.transfer_variable.generate_transferid(
                self.transfer_variable.converge_flag, iter_num)
            converge_flag = federation.get(
                name=self.transfer_variable.converge_flag.name,
                tag=converge_flag_id,
                idx=0)

            self.n_iter_ = iter_num
            LOGGER.debug("converge flag is :{}".format(converge_flag))

            if converge_flag:
                self.is_converged = True
                break
Example No. 21
    def fit(self, data_instances, validate_data=None):
        LOGGER.debug("Start data count: {}".format(data_instances.count()))

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        # validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt,
                                                 suffix=('fit', ))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        self.model_weights = self._init_model_variables(data_instances)
        w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
        self.model_weights = LogisticRegressionWeights(
            w, self.model_weights.fit_intercept)

        LOGGER.debug("After init, model_weights: {}".format(
            self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
            re_encrypt_times = (total_batch_num -
                                1) // self.re_encrypt_batches + 1
            LOGGER.debug(
                "re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}"
                .format(re_encrypt_times, self.batch_size, total_batch_num,
                        self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        total_data_num = data_instances.count()
        LOGGER.debug("Current data count: {}".format(total_data_num))

        model_weights = self.model_weights
        degree = 0
        while self.n_iter_ < self.max_iter + 1:
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            if ((self.n_iter_ + 1) % self.aggregate_iters
                    == 0) or self.n_iter_ == self.max_iter:
                weight = self.aggregator.aggregate_then_get(
                    model_weights, degree=degree, suffix=self.n_iter_)
                # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format(
                #     model_weights.unboxed / degree,
                #     degree,
                #     weight.unboxed))
                self.model_weights = LogisticRegressionWeights(
                    weight.unboxed, self.fit_intercept)
                if not self.use_encrypt:
                    loss = self._compute_loss(data_instances)
                    self.aggregator.send_loss(loss,
                                              degree=degree,
                                              suffix=(self.n_iter_, ))
                    LOGGER.info("n_iters: {}, loss: {}".format(
                        self.n_iter_, loss))
                degree = 0
                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info("n_iters: {}, is_converge: {}".format(
                    self.n_iter_, self.is_converged))
                if self.is_converged or self.n_iter_ == self.max_iter:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                degree += n
                LOGGER.debug('before compute_gradient')
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.mapPartitions(f).reduce(
                    fate_operator.reduce_add)
                grad /= n
                model_weights = self.optimizer.update_model(model_weights,
                                                            grad,
                                                            has_applied=False)

                if self.use_encrypt and batch_num % self.re_encrypt_batches == 0:
                    LOGGER.debug(
                        "Before accept re_encrypted_model, batch_iter_num: {}".
                        format(batch_num))
                    w = self.cipher.re_cipher(w=model_weights.unboxed,
                                              iter_num=self.n_iter_,
                                              batch_iter_num=batch_num)
                    model_weights = LogisticRegressionWeights(
                        w, self.fit_intercept)
                batch_num += 1

            # validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1

        LOGGER.info("Finish Training task, total iters: {}".format(
            self.n_iter_))
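A quick check of the re-encryption schedule used above: with re-ciphering triggered whenever batch_num % re_encrypt_batches == 0, the (total_batch_num - 1) // re_encrypt_batches + 1 formula counts exactly how many times that happens per epoch. A small worked example with arbitrary numbers:

# Worked example of the re-encryption count formula.
total_batch_num = 7
re_encrypt_batches = 3

re_encrypt_times = (total_batch_num - 1) // re_encrypt_batches + 1
trigger_batches = [b for b in range(total_batch_num) if b % re_encrypt_batches == 0]

print(re_encrypt_times, trigger_batches)  # 3 [0, 3, 6]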
Example No. 22
    def fit(self,
            data_instances,
            node2id,
            local_instances=None,
            common_nodes=None):
        """
        Train the network-embedding (NE) model for the host role.

        Parameters
        ----------
        data_instances: DTable of anchor nodes, the input data
        """
        LOGGER.info("Enter hetero_ne host")
        self.n_node = len(node2id)
        LOGGER.info("Host party has {} nodes".format(self.n_node))

        data_instances = data_instances.mapValues(HeteroNEHost.load_data)
        LOGGER.info("Transform input data to train instance")

        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)
        LOGGER.info("Get Publick key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        ##############
        # horizontal federated learning
        LOGGER.info("Generate mini-batch for local instances in guest")
        mini_batch_obj_local = MiniBatch(local_instances,
                                         batch_size=self.batch_size)
        common_node_instances = eggroll.parallelize(
            ((node, node) for node in common_nodes),
            include_key=True,
            name='common_nodes')
        ##############

        batch_info = federation.get(
            name=self.transfer_variable.batch_info.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.batch_info),
            idx=0)
        LOGGER.info("Get batch_info from guest: {}".format(batch_info))

        self.batch_size = batch_info['batch_size']
        self.batch_num = batch_info['batch_num']
        if self.batch_size < consts.MIN_BATCH_SIZE and self.batch_size != -1:
            raise ValueError(
                "Batch size get from guest should not less than 10, except -1, batch_size is {}"
                .format(self.batch_size))

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.encrypt_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(self.batch_num)
        ]

        LOGGER.info("Start initilize model.")
        self.embedding_ = self.initializer.init_model((self.n_node, self.dim),
                                                      self.init_param_obj)

        self.n_iter_ = 0
        index_data_inst_map = {}

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter: {}".format(self.n_iter_))

            #################
            local_batch_data_generator = mini_batch_obj_local.mini_batch_data_generator()
            total_loss = 0
            local_batch_num = 0
            LOGGER.info("Horizontally learning")
            for local_batch_data in local_batch_data_generator:
                n = local_batch_data.count()
                LOGGER.info("Local batch data count:{}".format(n))
                E_Y = self.compute_local_embedding(local_batch_data,
                                                   self.embedding_, node2id)
                local_grads_e1, local_grads_e2, local_loss = self.local_gradient_operator.compute(
                    E_Y, 'E_1')
                local_grads_e1 = local_grads_e1.mapValues(
                    lambda g: self.local_optimizer.apply_gradients(g / n))
                local_grads_e2 = local_grads_e2.mapValues(
                    lambda g: self.local_optimizer.apply_gradients(g / n))
                e1id_join_grads = local_batch_data.join(
                    local_grads_e1, lambda v, g: (node2id[v[0]], g))
                e2id_join_grads = local_batch_data.join(
                    local_grads_e2, lambda v, g: (node2id[v[1]], g))
                self.update_model(e1id_join_grads)
                self.update_model(e2id_join_grads)

                local_loss = local_loss / n
                local_batch_num += 1
                total_loss += local_loss
                LOGGER.info("gradient count:{}".format(
                    e1id_join_grads.count()))

            host_common_embedding = common_node_instances.mapValues(
                lambda node: self.embedding_[node2id[node]])
            federation.remote(
                host_common_embedding,
                name=self.transfer_variable.host_common_embedding.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.host_common_embedding, self.n_iter_,
                    0),
                role=consts.ARBITER,
                idx=0)

            common_embedding = federation.get(
                name=self.transfer_variable.common_embedding.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.common_embedding, self.n_iter_, 0),
                idx=0)

            self.update_common_nodes(common_embedding, common_nodes, node2id)

            total_loss /= local_batch_num
            LOGGER.info("Iter {}, Local loss: {}".format(
                self.n_iter_, total_loss))

            batch_index = 0
            while batch_index < self.batch_num:
                LOGGER.info("batch:{}".format(batch_index))

                # cache the batch index so it is not re-communicated in later iterations;
                # the sequence of batches is the same in every iteration
                if len(self.batch_index_list) < self.batch_num:
                    batch_data_index = federation.get(
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        idx=0)
                    LOGGER.info("Get batch_index from Guest")
                    self.batch_index_list.append(batch_data_index)
                else:
                    batch_data_index = self.batch_index_list[batch_index]

                # Get mini-batch train data; cache the join result
                # so it is not recomputed in later iterations
                if len(index_data_inst_map) < self.batch_num:
                    batch_data_inst = batch_data_index.join(
                        data_instances, lambda g, d: d)
                    index_data_inst_map[batch_index] = batch_data_inst
                else:
                    batch_data_inst = index_data_inst_map[batch_index]

                LOGGER.info("batch_data_inst size:{}".format(
                    batch_data_inst.count()))

                #self.transform(data_inst)

                # compute forward
                host_forward = self.compute_forward(batch_data_inst,
                                                    self.embedding_, node2id,
                                                    batch_index)
                federation.remote(
                    host_forward,
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    role=consts.GUEST,
                    idx=0)
                LOGGER.info("Remote host_forward to guest")

                # Get optimize host gradient and update model
                optim_host_gradient = federation.get(
                    name=self.transfer_variable.host_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_host_gradient from arbiter")

                nodeid_join_gradient = batch_data_inst.join(
                    optim_host_gradient, lambda instance, gradient:
                    (node2id[instance.features], gradient))
                LOGGER.info("update_model")
                self.update_model(nodeid_join_gradient)

                # update local model
                #training_info = {"iteration": self.n_iter_, "batch_index": batch_index}
                #self.update_local_model(fore_gradient, batch_data_inst, self.coef_, **training_info)

                batch_index += 1

                rubbish_list = [host_forward]
                rubbish_clear(rubbish_list)

            #######
            host_common_embedding = common_node_instances.mapValues(
                lambda node: self.embedding_[node2id[node]])
            federation.remote(
                host_common_embedding,
                name=self.transfer_variable.host_common_embedding.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.host_common_embedding, self.n_iter_,
                    1),
                role=consts.ARBITER,
                idx=0)

            common_embedding = federation.get(
                name=self.transfer_variable.common_embedding.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.common_embedding, self.n_iter_, 1),
                idx=0)

            self.update_common_nodes(common_embedding, common_nodes, node2id)
            #######

            is_stopped = federation.get(
                name=self.transfer_variable.is_stopped.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.is_stopped,
                    self.n_iter_,
                ),
                idx=0)
            LOGGER.info("Get is_stop flag from arbiter:{}".format(is_stopped))

            self.n_iter_ += 1
            if is_stopped:
                break

        LOGGER.info("Reach max iter {}, train mode finish!".format(
            self.max_iter))
        embedding_table = eggroll.table(name='host',
                                        namespace='node_embedding',
                                        partition=10)
        id2node = dict(zip(node2id.values(), node2id.keys()))
        for id, embedding in enumerate(self.embedding_):
            embedding_table.put(id2node[id], embedding)
        embedding_table.save_as(name='host',
                                namespace='node_embedding',
                                partition=10)
        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))
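Persisting the embeddings relies on inverting node2id back into an id-to-node mapping. A small self-contained sketch of that step, with a toy node2id and random embeddings standing in for the eggroll tables:

# Sketch of the id2node inversion used when writing out the embedding table.
import numpy as np

node2id = {"a": 0, "b": 1, "c": 2}
embedding_ = np.random.rand(len(node2id), 4)

id2node = dict(zip(node2id.values(), node2id.keys()))
embedding_table = {id2node[i]: emb for i, emb in enumerate(embedding_)}
print(sorted(embedding_table))  # ['a', 'b', 'c']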