Example #1
 def check_host_number(self, tree_type):
     host_num = len(self.component_properties.host_party_idlist)
     LOGGER.info('host number is {}'.format(host_num))
     if tree_type == plan.tree_type_dict['layered_tree']:
         assert host_num == 1, 'only 1 host party is allowed in layered mode'
Example #2
 def remove_redundant_splitinfo_in_split_maskdict(self, split_nid_used):
     LOGGER.info("remove duplicated nodes from split mask dict")
     duplicated_nodes = set(
         self.split_maskdict.keys()) - set(split_nid_used)
     for nid in duplicated_nodes:
         del self.split_maskdict[nid]
Example #3
    def fit(self):
        """
        start to fit
        """
        LOGGER.info(
            'begin to fit homo decision tree, epoch {}, tree idx {}'.format(
                self.epoch_idx, self.tree_idx))

        # compute local g_sum and h_sum
        g_sum, h_sum = self.get_grad_hess_sum(self.g_h)

        # get aggregated root info
        self.aggregator.send_local_root_node_info(g_sum,
                                                  h_sum,
                                                  suffix=('root_node_sync1',
                                                          self.epoch_idx))
        g_h_dict = self.aggregator.get_aggregated_root_info(
            suffix=('root_node_sync2', self.epoch_idx))
        global_g_sum, global_h_sum = g_h_dict['g_sum'], g_h_dict['h_sum']

        # initialize node
        root_node = Node(id=0,
                         sitename=consts.GUEST,
                         sum_grad=global_g_sum,
                         sum_hess=global_h_sum,
                         weight=self.splitter.node_weight(
                             global_g_sum, global_h_sum))

        self.cur_layer_node = [root_node]
        LOGGER.debug('assign samples to root node')
        self.inst2node_idx = self.assign_instance_to_root_node(
            self.data_bin, 0)

        tree_height = self.max_depth + 1  # non-leaf node height + 1 layer leaf

        for dep in range(tree_height):

            if dep + 1 == tree_height:

                for node in self.cur_layer_node:
                    node.is_leaf = True
                    self.tree_node.append(node)

                rest_sample_weights = self.get_node_sample_weights(
                    self.inst2node_idx, self.tree_node)
                if self.sample_weights is None:
                    self.sample_weights = rest_sample_weights
                else:
                    self.sample_weights = self.sample_weights.union(
                        rest_sample_weights)

                # stop fitting
                break

            LOGGER.debug('start to fit layer {}'.format(dep))

            table_with_assignment = self.update_instances_node_positions()

            # send current layer node number:
            self.sync_cur_layer_node_num(len(self.cur_layer_node),
                                         suffix=(dep, self.epoch_idx,
                                                 self.tree_idx))

            split_info, agg_histograms = [], []
            for batch_id, i in enumerate(
                    range(0, len(self.cur_layer_node), self.max_split_nodes)):
                cur_to_split = self.cur_layer_node[i:i + self.max_split_nodes]

                node_map = self.get_node_map(nodes=cur_to_split)
                LOGGER.debug('node map is {}'.format(node_map))
                LOGGER.debug(
                    'computing histogram for batch {} at depth {}'.format(
                        batch_id, dep))
                local_histogram = self.get_left_node_local_histogram(
                    cur_nodes=cur_to_split,
                    tree=self.tree_node,
                    g_h=self.g_h,
                    table_with_assign=table_with_assignment,
                    split_points=self.bin_split_points,
                    sparse_point=self.bin_sparse_points,
                    valid_feature=self.valid_features)

                LOGGER.debug(
                    'federated finding best splits for batch {} at layer {}'
                    .format(batch_id, dep))
                self.sync_local_node_histogram(local_histogram,
                                               suffix=(batch_id, dep,
                                                       self.epoch_idx,
                                                       self.tree_idx))

                agg_histograms += local_histogram

            split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
            LOGGER.debug('got best splits from arbiter')

            new_layer_node = self.update_tree(self.cur_layer_node, split_info)
            self.cur_layer_node = new_layer_node

            self.inst2node_idx, leaf_val = self.assign_instances_to_new_node(
                table_with_assignment, self.tree_node)

            # record leaf val
            if self.sample_weights is None:
                self.sample_weights = leaf_val
            else:
                self.sample_weights = self.sample_weights.union(leaf_val)

            LOGGER.debug('assigning instance to new nodes done')

        self.convert_bin_to_real()
        LOGGER.debug('fitting tree done')
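
Note on the batching above: the per-layer loop splits cur_layer_node into chunks of at most max_split_nodes before computing histograms. A minimal standalone sketch of that chunking pattern (plain Python, names hypothetical):

    def chunk_nodes(nodes, max_split_nodes):
        # Mirrors the batching in fit(): yields (batch_id, node-slice) pairs.
        for batch_id, i in enumerate(range(0, len(nodes), max_split_nodes)):
            yield batch_id, nodes[i:i + max_split_nodes]

    # 7 nodes with max_split_nodes=3 -> batches of sizes 3, 3 and 1
    assert [len(b) for _, b in chunk_nodes(list(range(7)), 3)] == [3, 3, 1]
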
Example #4
 def sync_dispatch_node_host(self, dep):
     LOGGER.info("get node from host to dispath, depth is {}".format(dep))
     dispatch_node_host = self.transfer_inst.dispatch_node_host.get(
         idx=0, suffix=(dep, ))
     return dispatch_node_host
Example #5
 def sync_predict_finish_tag(self, recv_times):
     LOGGER.info(
         "get the {}-th predict finish tag from guest".format(recv_times))
     finish_tag = self.transfer_inst.predict_finish_tag.get(
         idx=0, suffix=(recv_times, ))
     return finish_tag
Example #6
 def set_encrypter(self, encrypter):
     LOGGER.info("set encrypter")
     self.encrypter = encrypter
Example #7
 def sync_node_positions(self, dep=-1):
     LOGGER.info("get tree node queue of depth {}".format(dep))
     node_positions = self.transfer_inst.node_positions.get(idx=0,
                                                            suffix=(dep, ))
     return node_positions
Example #8
    def fit_binary(self, data_instances, validate_data=None):
        LOGGER.info("Enter hetero_lr_guest fit")
        self.header = self.get_header(data_instances)

        self.validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        data_instances = data_instances.mapValues(HeteroLRGuest.load_data)
        LOGGER.debug(
            f"MODEL_STEP After load data, data count: {data_instances.count()}"
        )
        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        LOGGER.info("Generate mini-batch from input data")
        self.batch_generator.initialize_batch_generator(
            data_instances, self.batch_size)
        self.gradient_loss_operator.set_total_batch_nums(
            self.batch_generator.batch_nums)

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.cipher_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(self.batch_generator.batch_nums)
        ]

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(
            self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        w = self.initializer.init_model(model_shape,
                                        init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(
            w, fit_intercept=self.fit_intercept)

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))
            batch_data_generator = self.batch_generator.generate_batch_data()
            self.optimizer.set_iters(self.n_iter_)
            batch_index = 0
            for batch_data in batch_data_generator:
                # no feature transform here: the raw batch is used as 'batch_feat_inst'
                batch_feat_inst = batch_data
                # LOGGER.debug(f"MODEL_STEP In Batch {batch_index}, batch data count: {batch_feat_inst.count()}")

                # Start gradient procedure
                LOGGER.debug(
                    "iter: {}, before compute gradient, data count: {}".format(
                        self.n_iter_, batch_feat_inst.count()))
                optim_guest_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst, self.encrypted_calculator,
                    self.model_weights, self.optimizer, self.n_iter_,
                    batch_index)

                # LOGGER.debug('optim_guest_gradient: {}'.format(optim_guest_gradient))
                # training_info = {"iteration": self.n_iter_, "batch_index": batch_index}
                # self.update_local_model(fore_gradient, data_instances, self.model_weights.coef_, **training_info)

                loss_norm = self.optimizer.loss_norm(self.model_weights)
                self.gradient_loss_operator.compute_loss(
                    data_instances, self.model_weights, self.n_iter_,
                    batch_index, loss_norm)

                self.model_weights = self.optimizer.update_model(
                    self.model_weights, optim_guest_gradient)
                batch_index += 1
                # LOGGER.debug("lr_weight, iters: {}, update_model: {}".format(self.n_iter_, self.model_weights.unboxed))

            self.is_converged = self.converge_procedure.sync_converge_info(
                suffix=(self.n_iter_, ))
            LOGGER.info("iter: {},  is_converged: {}".format(
                self.n_iter_, self.is_converged))

            if self.validation_strategy:
                LOGGER.debug('LR guest running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            self.n_iter_ += 1

            if self.is_converged:
                break

        if self.validation_strategy and self.validation_strategy.has_saved_best_model(
        ):
            self.load_model(self.validation_strategy.cur_best_model)
        self.set_summary(self.get_model_summary())
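
Stripped of federation, encryption and logging, the control flow above is a standard mini-batch loop with a synced convergence check and optional early stopping. A hypothetical plain-Python skeleton of that flow (all names are stand-ins, not FATE APIs):

    def fit_loop(model, max_iter, make_batches, step, synced_converged,
                 validate=None):
        n_iter = 0
        while n_iter < max_iter:
            for batch in make_batches():          # fresh generator each iteration
                model = step(model, batch)        # gradient procedure + update
            converged = synced_converged(n_iter)  # agreed with arbiter in FATE
            if validate is not None and validate(model, n_iter):
                break                             # early stopping
            n_iter += 1
            if converged:
                break
        return model
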
Example #9
    def unified_calculation_process(self, data_instances):
        LOGGER.info("RSA intersect using unified calculation.")
        # generate rsa keys
        # self.e, self.d, self.n = self.generate_protocol_key()
        self.generate_protocol_key()
        LOGGER.info("Generate protocol key!")
        public_key = {"e": self.e, "n": self.n}

        # sends public key e & n to guest
        self.transfer_variable.host_pubkey.remote(public_key,
                                                  role=consts.GUEST,
                                                  idx=0)
        LOGGER.info("Remote public key to Guest.")
        # hash and sign host ids with the private key
        prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(
            data_instances, self.d, self.n, self.p, self.q, self.cp, self.cq,
            self.first_hash_operator)

        prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: 1)
        self.transfer_variable.host_prvkey_ids.remote(prvkey_ids_process,
                                                      role=consts.GUEST,
                                                      idx=0)
        LOGGER.info("Remote host_ids_process to Guest.")

        # Recv guest ids
        guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
        LOGGER.info("Get guest_pubkey_ids from guest")

        # Sign guest ids and return them to guest
        host_sign_guest_ids = guest_pubkey_ids.map(
            lambda k, v: (k, self.sign_id(k, self.d, self.n, self.p, self.q,
                                          self.cp, self.cq)))
        self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
                                                          role=consts.GUEST,
                                                          idx=0)
        LOGGER.info("Remote host_sign_guest_ids_process to Guest.")

        # recv intersect ids
        intersect_ids = None
        if self.sync_intersect_ids:
            encrypt_intersect_ids = self.transfer_variable.intersect_ids.get(
                idx=0)
            intersect_ids_pair = encrypt_intersect_ids.join(
                prvkey_ids_process_pair, lambda e, h: h)
            intersect_ids = intersect_ids_pair.map(lambda k, v: (v, "id"))
            LOGGER.info("Get intersect ids from Guest")

        return intersect_ids
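
The closing join recovers the original host ids: prvkey_ids_process_pair maps processed keys to source ids, and the guest returns the processed keys it matched. The same recovery on plain dicts (toy data):

    prvkey_pairs = {'enc_a': 'id_a', 'enc_b': 'id_b'}  # processed key -> original id
    guest_hits = {'enc_a': 1}                          # keys the guest matched
    intersect_ids = {orig: 'id' for key, orig in prvkey_pairs.items()
                     if key in guest_hits}
    assert intersect_ids == {'id_a': 'id'}
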
Example #10
    def fit(self, data_instances, validate_data=None):
        """
        Train the linR model on the guest side
        Parameters
        ----------
        data_instances: DTable of Instance, input data
        """

        LOGGER.info("Enter hetero_linR_guest fit")
        self._abnormal_detection(data_instances)
        self.header = self.get_header(data_instances)

        self.validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)

        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        LOGGER.info("Generate mini-batch from input data")
        self.batch_generator.initialize_batch_generator(
            data_instances, self.batch_size)
        self.gradient_loss_operator.set_total_batch_nums(
            self.batch_generator.batch_nums)

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.cipher_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(self.batch_generator.batch_nums)
        ]

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(
            self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        w = self.initializer.init_model(model_shape,
                                        init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(
            w, fit_intercept=self.fit_intercept)

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))
            # each iter will get the same batch_data_generator
            batch_data_generator = self.batch_generator.generate_batch_data()
            self.optimizer.set_iters(self.n_iter_)
            batch_index = 0
            for batch_data in batch_data_generator:
                # transform raw 'batch_data' into more representative features 'batch_feat_inst'
                batch_feat_inst = self.transform(batch_data)

                # Start gradient procedure
                optim_guest_gradient, _, _ = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst, self.encrypted_calculator,
                    self.model_weights, self.optimizer, self.n_iter_,
                    batch_index)

                loss_norm = self.optimizer.loss_norm(self.model_weights)
                self.gradient_loss_operator.compute_loss(
                    data_instances, self.n_iter_, batch_index, loss_norm)

                self.model_weights = self.optimizer.update_model(
                    self.model_weights, optim_guest_gradient)
                batch_index += 1
                # LOGGER.debug(
                #     "model_weights, iters: {}, update_model: {}".format(self.n_iter_, self.model_weights.unboxed))

            self.is_converged = self.converge_procedure.sync_converge_info(
                suffix=(self.n_iter_, ))
            LOGGER.info("iter: {}, is_converged: {}".format(
                self.n_iter_, self.is_converged))

            # LOGGER.debug("model weights is {}".format(self.model_weights.coef_))

            if self.validation_strategy:
                LOGGER.debug('LinR guest running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            self.n_iter_ += 1
            if self.is_converged:
                break
        if self.validation_strategy and self.validation_strategy.has_saved_best_model(
        ):
            self.load_model(self.validation_strategy.cur_best_model)
        self.set_summary(self.get_model_summary())
Example #11
 def get_grad_and_hess(g_h, dim=0):
     LOGGER.info("get grad and hess of tree {}".format(dim))
     grad_and_hess_subtree = g_h.mapValues(lambda grad_and_hess: (
         grad_and_hess[0][dim], grad_and_hess[1][dim]))
     return grad_and_hess_subtree
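
With multi-dimensional boosting (e.g. one tree per class), every sample carries a tuple of gradients and a tuple of hessians, and the mapValues above selects dimension dim from each. The same selection on an in-memory dict:

    g_h = {'x1': ((0.1, -0.4, 0.3), (1.0, 1.0, 1.0))}  # (grads, hessians) per class
    dim = 1
    subtree = {k: (gh[0][dim], gh[1][dim]) for k, gh in g_h.items()}
    assert subtree == {'x1': (-0.4, 1.0)}
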
Example #12
    def intersect_online_process(self, data_inst, caches):
        # LOGGER.debug(f"caches is: {caches}")
        cache_data, cache_meta = list(caches.values())[0]
        intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
        # LOGGER.debug(f"intersect_meta is: {intersect_meta}")
        self.callback_cache_meta(intersect_meta)
        self.load_intersect_meta(intersect_meta)
        self.init_intersect_method()
        self.intersection_obj.load_intersect_key(cache_meta)

        if data_overview.check_with_inst_id(data_inst):
            self.use_match_id_process = True
            LOGGER.info(f"use match_id_process")
        intersect_data = data_inst
        if self.use_match_id_process:
            if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
                raise ValueError("While multi-host, sample_id_generator should be guest.")
            if self.model_param.intersect_method == consts.RAW:
                if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                    raise ValueError(f"When using raw intersect with match id process,"
                                     f"'join_role' should be same role as 'sample_id_generator'")
            else:
                if not self.model_param.sync_intersect_ids:
                    if self.model_param.sample_id_generator != consts.GUEST:
                        self.model_param.sample_id_generator = consts.GUEST
                        LOGGER.warning(f"when not sync_intersect_ids with match id process,"
                                       f"sample_id_generator is set to Guest")

            proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
            proc_obj.new_sample_id = self.model_param.new_sample_id
            if data_overview.check_with_inst_id(data_inst) or self.model_param.with_sample_id:
                proc_obj.use_sample_id()
            match_data = proc_obj.recover(data=data_inst)
            intersect_data = match_data

        if self.role == consts.HOST:
            cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
            self.transfer_variable.cache_id.remote(cache_id, role=consts.GUEST, idx=0)
            guest_cache_id = self.transfer_variable.cache_id.get(role=consts.GUEST, idx=0)
            if guest_cache_id != cache_id:
                raise ValueError(f"cache_id check failed. cache_id from host & guest must match.")
        elif self.role == consts.GUEST:
            for i, party_id in enumerate(self.host_party_id_list):
                cache_id = cache_meta[str(party_id)].get("cache_id")
                self.transfer_variable.cache_id.remote(cache_id,
                                                       role=consts.HOST,
                                                       idx=i)
                host_cache_id = self.transfer_variable.cache_id.get(role=consts.HOST, idx=i)
                if host_cache_id != cache_id:
                    raise ValueError(f"cache_id check failed. cache_id from host & guest must match.")
        else:
            raise ValueError(f"Role {self.role} cannot run intersection transform.")

        self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
        if self.use_match_id_process:
            if not self.model_param.sync_intersect_ids:
                self.intersect_ids = proc_obj.expand(self.intersect_ids,
                                                     match_data=match_data,
                                                     owner_only=True)
            else:
                self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
            if self.intersect_ids and self.model_param.only_output_key:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
                                             "sid_name": data_inst.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            data_count = data_inst.count()
            self.intersect_num = self.intersect_ids.count()
            self.intersect_rate = self.intersect_num / data_count
            self.unmatched_num = data_count - self.intersect_num
            self.unmatched_rate = 1 - self.intersect_rate

        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process:
            if not self.intersection_obj.only_output_key and result_data:
                result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
                self.intersect_ids.schema = result_data.schema
                LOGGER.debug(f"not only_output_key, restore value called")
            if self.intersection_obj.only_output_key and result_data:
                schema = {"sid_name": data_inst.schema["sid_name"]}
                result_data = result_data.mapValues(lambda v: 1)
                result_data.schema = schema
                self.intersect_ids.schema = schema

        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data_inst, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema

        return result_data
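
The summary counters near the end are simple ratios over the input size; for a toy run:

    data_count, intersect_num = 1000, 750
    intersect_rate = intersect_num / data_count   # 0.75
    unmatched_num = data_count - intersect_num    # 250
    unmatched_rate = 1 - intersect_rate           # 0.25
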
Example #13
    def fit(self, data):
        if self.component_properties.caches:
            return self.intersect_online_process(data, self.component_properties.caches)
        self.init_intersect_method()
        if data_overview.check_with_inst_id(data):
            self.use_match_id_process = True
            LOGGER.info(f"use match_id_process")

        if self.use_match_id_process:
            if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
                raise ValueError("While multi-host, sample_id_generator should be guest.")
            if self.model_param.intersect_method == consts.RAW:
                if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                    raise ValueError(f"When using raw intersect with match id process,"
                                     f"'join_role' should be same role as 'sample_id_generator'")
            else:
                if not self.model_param.sync_intersect_ids:
                    if self.model_param.sample_id_generator != consts.GUEST:
                        self.model_param.sample_id_generator = consts.GUEST
                        LOGGER.warning(f"when not sync_intersect_ids with match id process,"
                                         f"sample_id_generator is set to Guest")

            self.proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
            self.proc_obj.new_sample_id = self.model_param.new_sample_id
            if data_overview.check_with_inst_id(data) or self.model_param.with_sample_id:
                self.proc_obj.use_sample_id()
            match_data = self.proc_obj.recover(data=data)
            if self.intersection_obj.run_cache:
                self.cache_output = self.intersection_obj.generate_cache(match_data)
                intersect_meta = self.intersection_obj.get_intersect_method_meta()
                self.callback_cache_meta(intersect_meta)
                return data
            if self.intersection_obj.cardinality_only:
                self.intersection_obj.run_cardinality(match_data)
            else:
                intersect_data = match_data
                if self.model_param.run_preprocess:
                    intersect_data = self.run_preprocess(match_data)
                self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)
        else:
            if self.intersection_obj.run_cache:
                self.cache_output = self.intersection_obj.generate_cache(data)
                intersect_meta = self.intersection_obj.get_intersect_method_meta()
                # LOGGER.debug(f"callback intersect meta is: {intersect_meta}")
                self.callback_cache_meta(intersect_meta)
                return data
            if self.intersection_obj.cardinality_only:
                self.intersection_obj.run_cardinality(data)
            else:
                intersect_data = data
                if self.model_param.run_preprocess:
                    intersect_data = self.run_preprocess(data)
                self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)

        if self.intersection_obj.cardinality_only:
            if self.intersection_obj.intersect_num is not None:
                data_count = data.count()
                self.intersect_num = self.intersection_obj.intersect_num
                self.intersect_rate = self.intersect_num / data_count
                self.unmatched_num = data_count - self.intersect_num
                self.unmatched_rate = 1 - self.intersect_rate
            # self.model = self.intersection_obj.get_model()
            self.set_summary(self.get_model_summary())
            self.callback()
            return data

        if self.use_match_id_process:
            if self.model_param.sync_intersect_ids:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
            else:
                # self.intersect_ids = match_data
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids,
                                                          match_data=match_data,
                                                          owner_only=True)
            if self.model_param.only_output_key and self.intersect_ids:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
                                             "sid_name": data.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            data_count = data.count()
            self.intersect_num = self.intersect_ids.count()
            self.intersect_rate = self.intersect_num / data_count
            self.unmatched_num = data_count - self.intersect_num
            self.unmatched_rate = 1 - self.intersect_rate

        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process and not self.intersection_obj.only_output_key and result_data:
            result_data = self.intersection_obj.get_value_from_data(result_data, data)
            LOGGER.debug(f"not only_output_key, restore value called")

        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema

        return result_data
Example #14
    def predict(self, data_inst, ret_format='std'):

        # standard format, leaf indices, raw score
        assert ret_format in ['std', 'leaf', 'raw'], 'illegal ret format'

        LOGGER.info('running prediction')
        cache_dataset_key = self.predict_data_cache.get_data_key(data_inst)

        processed_data = self.data_and_header_alignment(data_inst)

        last_round = self.predict_data_cache.predict_data_last_round(
            cache_dataset_key)

        self.sync_predict_round(last_round)

        rounds = len(self.boosting_model_list) // self.booster_dim
        trees = []
        LOGGER.debug(
            'round involved in prediction {}, last round is {}, data key {}'.
            format(list(range(last_round, rounds)), last_round,
                   cache_dataset_key))

        for idx in range(last_round, rounds):
            for booster_idx in range(self.booster_dim):
                tree = self.load_learner(
                    self.booster_meta,
                    self.boosting_model_list[idx * self.booster_dim +
                                             booster_idx], idx, booster_idx)
                trees.append(tree)

        predict_cache = None
        tree_num = len(trees)

        if last_round != 0:
            predict_cache = self.predict_data_cache.predict_data_at(
                cache_dataset_key, min(rounds, last_round))
            LOGGER.info('load predict cache of round {}'.format(
                min(rounds, last_round)))

        if tree_num == 0 and predict_cache is not None and ret_format != 'leaf':
            return self.score_to_predict_result(data_inst, predict_cache)

        if self.boosting_strategy == consts.MIX_TREE:
            predict_rs = mix_sbt_guest_predict(
                processed_data,
                self.hetero_sbt_transfer_variable,
                trees,
                self.learning_rate,
                self.init_score,
                self.booster_dim,
                predict_cache,
                pred_leaf=(ret_format == 'leaf'))
        else:
            if self.EINI_inference and not self.on_training:  # EINI is for inference stage
                sitename = self.role + ':' + str(
                    self.component_properties.local_partyid)
                predict_rs = EINI_guest_predict(
                    processed_data, trees, self.learning_rate, self.init_score,
                    self.booster_dim, self.encrypt_param.key_length,
                    self.hetero_sbt_transfer_variable, sitename,
                    self.component_properties.host_party_idlist, predict_cache,
                    False)
            else:
                predict_rs = sbt_guest_predict(
                    processed_data,
                    self.hetero_sbt_transfer_variable,
                    trees,
                    self.learning_rate,
                    self.init_score,
                    self.booster_dim,
                    predict_cache,
                    pred_leaf=(ret_format == 'leaf'))

        if ret_format == 'leaf':
            return predict_rs  # predict result is leaf position

        self.predict_data_cache.add_data(cache_dataset_key,
                                         predict_rs,
                                         cur_boosting_round=rounds)
        LOGGER.debug('adding predict rs {}'.format(predict_rs))
        LOGGER.debug('last round is {}'.format(
            self.predict_data_cache.predict_data_last_round(
                cache_dataset_key)))

        if ret_format == 'raw':
            return predict_rs
        else:
            return self.score_to_predict_result(data_inst, predict_rs)
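
Boosters are stored flat, one tree per (round, class) pair, so rounds = len(boosting_model_list) // booster_dim and the tree for round idx and class booster_idx sits at index idx * booster_dim + booster_idx. A toy check of that layout:

    booster_dim = 3  # e.g. a 3-class model
    model_list = ['tree_r{}_c{}'.format(r, c)
                  for r in range(2) for c in range(booster_dim)]
    rounds = len(model_list) // booster_dim
    assert rounds == 2
    assert model_list[1 * booster_dim + 2] == 'tree_r1_c2'
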
Example #15
 def sync_predict_data(self, predict_data, send_times):
     LOGGER.info("send predict data to host, sending times is {}".format(send_times))
     self.transfer_inst.predict_data.remote(predict_data,
                                            role=consts.HOST,
                                            idx=-1,
                                            suffix=(send_times,))
Example #16
    def run_cardinality(self, data_instances):
        LOGGER.info(f"run cardinality_only with RSA")
        # generate rsa keys
        self.generate_protocol_key()
        LOGGER.info("Generate protocol key!")
        public_key = {"e": self.e, "n": self.n}

        # sends public key e & n to guest
        self.transfer_variable.host_pubkey.remote(public_key,
                                                  role=consts.GUEST,
                                                  idx=0)
        LOGGER.info("Remote public key to Guest.")
        # hash and sign host ids with the private key
        prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(
            data_instances, self.d, self.n, self.p, self.q, self.cp, self.cq,
            self.first_hash_operator)

        id_filter = self.construct_filter(
            prvkey_ids_process_pair,
            false_positive_rate=self.intersect_preprocess_params.false_positive_rate,
            hash_method=self.intersect_preprocess_params.hash_method,
            random_state=self.intersect_preprocess_params.random_state)
        self.filter = id_filter
        self.transfer_variable.host_filter.remote(id_filter,
                                                  role=consts.GUEST,
                                                  idx=0)
        LOGGER.info("Remote host_filter to Guest.")

        # Recv guest ids
        guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
        LOGGER.info("Get guest_pubkey_ids from guest")

        # Sign guest ids and return them to guest
        host_sign_guest_ids = guest_pubkey_ids.map(
            lambda k, v: (k, self.sign_id(k, self.d, self.n, self.p, self.q,
                                          self.cp, self.cq)))
        self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
                                                          role=consts.GUEST,
                                                          idx=0)
        LOGGER.info("Remote host_sign_guest_ids_process to Guest.")

        if self.sync_cardinality:
            self.intersect_num = self.transfer_variable.cardinality.get(idx=0)
            LOGGER.info("Got intersect cardinality from guest.")

        return data_instances
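
construct_filter builds a probabilistic membership filter over the host's processed ids, sized by false_positive_rate; a Bloom filter is the classic structure with exactly that tuning knob. A minimal sketch of the idea (hypothetical, not FATE's implementation):

    import hashlib

    class TinyBloom:
        # Toy Bloom filter: k hash positions per item over one big bit field.
        def __init__(self, size_bits=1024, n_hashes=3):
            self.size, self.k, self.bits = size_bits, n_hashes, 0

        def _positions(self, item):
            for i in range(self.k):
                d = hashlib.sha256('{}:{}'.format(i, item).encode()).digest()
                yield int.from_bytes(d[:8], 'big') % self.size

        def add(self, item):
            for pos in self._positions(item):
                self.bits |= 1 << pos

        def __contains__(self, item):
            return all(self.bits >> pos & 1 for pos in self._positions(item))

    bloom = TinyBloom()
    bloom.add('id_42')
    assert 'id_42' in bloom
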
Example #17
 def sync_data_predicted_by_host(self, send_times):
     LOGGER.info("get predicted data by host, recv times is {}".format(send_times))
     predict_data = self.transfer_inst.predict_data_by_host.get(idx=-1,
                                                                suffix=(send_times,))
     return predict_data
Example #18
    def split_calculation_process(self, data_instances):
        LOGGER.info("RSA intersect using split calculation.")
        # split data
        sid_hash_odd = data_instances.filter(lambda k, v: k & 1)
        sid_hash_even = data_instances.filter(lambda k, v: not k & 1)
        # LOGGER.debug(f"sid_hash_odd count: {sid_hash_odd.count()},"
        #              f"odd fraction: {sid_hash_odd.count()/data_instances.count()}")

        # generate rsa keys
        # self.e, self.d, self.n = self.generate_protocol_key()
        self.generate_protocol_key()
        LOGGER.info("Generate host protocol key!")
        public_key = {"e": self.e, "n": self.n}

        # sends public key e & n to guest
        self.transfer_variable.host_pubkey.remote(public_key,
                                                  role=consts.GUEST,
                                                  idx=0)
        LOGGER.info("Remote public key to Guest.")

        # generate ri for even ids
        # count = sid_hash_even.count()
        # self.r = self.generate_r_base(self.random_bit, count, self.random_base_fraction)
        # LOGGER.info(f"Generate {len(self.r)} r values.")

        # receive guest key for even ids
        guest_public_key = self.transfer_variable.guest_pubkey.get(idx=0)
        # LOGGER.debug("Get guest_public_key:{} from Guest".format(guest_public_key))
        LOGGER.info(f"Get guest_public_key from Guest")

        self.rcv_e = int(guest_public_key["e"])
        self.rcv_n = int(guest_public_key["n"])

        # encrypt & send guest pubkey-encrypted odd ids
        pubkey_ids_process = self.pubkey_id_process(
            sid_hash_even,
            fraction=self.random_base_fraction,
            random_bit=self.random_bit,
            rsa_e=self.rcv_e,
            rsa_n=self.rcv_n)
        LOGGER.info(f"Finish pubkey_ids_process")
        mask_host_id = pubkey_ids_process.mapValues(lambda v: 1)
        self.transfer_variable.host_pubkey_ids.remote(mask_host_id,
                                                      role=consts.GUEST,
                                                      idx=0)
        LOGGER.info("Remote host_pubkey_ids to Guest")

        # encrypt & send prvkey-encrypted host odd ids to guest
        prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(
            sid_hash_odd, self.d, self.n, self.p, self.q, self.cp, self.cq)
        prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: 1)

        self.transfer_variable.host_prvkey_ids.remote(prvkey_ids_process,
                                                      role=consts.GUEST,
                                                      idx=0)
        LOGGER.info("Remote host_prvkey_ids to Guest.")

        # get & sign guest pubkey-encrypted odd ids
        guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
        LOGGER.info(f"Get guest_pubkey_ids from guest")
        host_sign_guest_ids = guest_pubkey_ids.map(
            lambda k, v: (k, self.sign_id(k, self.d, self.n, self.p, self.q,
                                          self.cp, self.cq)))
        LOGGER.debug(f"host sign guest_pubkey_ids")
        # send signed guest odd ids
        self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
                                                          role=consts.GUEST,
                                                          idx=0)
        LOGGER.info("Remote host_sign_guest_ids_process to Guest.")

        # recv guest privkey-encrypted even ids
        guest_prvkey_ids = self.transfer_variable.guest_prvkey_ids.get(idx=0)
        LOGGER.info("Get guest_prvkey_ids")

        # receive guest-signed host even ids
        recv_guest_sign_host_ids = self.transfer_variable.guest_sign_host_ids.get(
            idx=0)
        LOGGER.info(f"Get guest_sign_host_ids from Guest.")
        guest_sign_host_ids = pubkey_ids_process.join(
            recv_guest_sign_host_ids,
            lambda g, r: (g[0],
                          RsaIntersectionHost.hash(
                              gmpy2.divm(int(r), int(g[1]), self.rcv_n),
                              self.final_hash_operator,
                              self.rsa_params.salt)))
        sid_guest_sign_host_ids = guest_sign_host_ids.map(lambda k, v:
                                                          (v[1], v[0]))

        encrypt_intersect_even_ids = sid_guest_sign_host_ids.join(
            guest_prvkey_ids, lambda sid, h: sid)

        # filter & send intersect even ids
        intersect_even_ids = self.filter_intersect_ids(
            [encrypt_intersect_even_ids])

        remote_intersect_even_ids = encrypt_intersect_even_ids.mapValues(
            lambda v: 1)
        self.transfer_variable.host_intersect_ids.remote(
            remote_intersect_even_ids, role=consts.GUEST, idx=0)
        LOGGER.info(f"Remote host intersect ids to Guest")

        # recv intersect ids
        intersect_ids = None
        if self.sync_intersect_ids:
            encrypt_intersect_odd_ids = self.transfer_variable.intersect_ids.get(
                idx=0)
            intersect_odd_ids_pair = encrypt_intersect_odd_ids.join(
                prvkey_ids_process_pair, lambda e, h: h)
            intersect_odd_ids = intersect_odd_ids_pair.map(lambda k, v: (v, 1))
            intersect_ids = intersect_odd_ids.union(intersect_even_ids)
            LOGGER.info("Get intersect ids from Guest")
        return intersect_ids
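
The unblinding step gmpy2.divm(int(r), int(g[1]), self.rcv_n) computes r * g[1]^(-1) mod n. Assuming pubkey_id_process stored the random blinding factor \rho in g[1] and sent the guest the blinded id h \cdot \rho^e \bmod n (standard RSA blinding), the guest's signature unblinds as

    (h \rho^e)^d \equiv h^d \rho \pmod{n}, \qquad (h^d \rho)\,\rho^{-1} \equiv h^d \pmod{n},

so the host recovers the guest's signature h^d on its own id without the id ever appearing in the clear.
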
Example #19
    def key_derivation(self, target_num):
        """
        Derive a list of keys for encryption and transmission
        :param target_num: N in k-N OT
        :return: List[ObliviousTransferKey]
        """
        LOGGER.info(
            "enter sender key derivation phase for target num = {}".format(
                target_num))
        # 1. Choose a random scalar y from Z^q and derive legal S and T from it
        y, s, t = self._gen_legal_y_s_t()

        # 2. Send S to the receiver; if the receiver flags it as illegal, regenerate y, S, T
        attempt_count = 0
        while True:
            self.transfer_variable.s.remote(s,
                                            suffix=(attempt_count, ),
                                            role=consts.GUEST,
                                            idx=0)
            # federation.remote(obj=s,
            #                   name=self.transfer_variable.s.name,
            #                   tag=self.transfer_variable.generate_transferid(self.transfer_variable.s, attempt_count),
            #                   role=consts.GUEST,
            #                   idx=0)
            LOGGER.info(
                "sent S to guest for the {}-th time".format(attempt_count))
            s_legal = self.transfer_variable.s_legal.get(
                idx=0, suffix=(attempt_count, ))
            # s_legal = federation.get(name=self.transfer_variable.s_legal.name,
            #                          tag=self.transfer_variable.generate_transferid(self.transfer_variable.s_legal,
            #                                                                         attempt_count),
            #                          idx=0)
            if s_legal:
                LOGGER.info(
                    "receiver confirms the legality of S on attempt {}, will proceed"
                    .format(attempt_count))
                break
            else:
                LOGGER.info(
                    "receiver rejects this S on attempt {}, will regenerate S"
                    .format(attempt_count))
                y, s, t = self._gen_legal_y_s_t()
                attempt_count += 1

        # 3. Wait for the receiver to hash S to get T
        LOGGER.info("waiting for the receiver to hash S to get T")

        # 4. Get R = cT + xG from the receiver, also init the MAC
        r = self.transfer_variable.r.get(idx=0)
        # r = federation.get(name=self.transfer_variable.r.name,
        #                    tag=self.transfer_variable.generate_transferid(self.transfer_variable.r),
        #                    idx=0)
        LOGGER.info("got from guest R = " + r.output())
        self._init_mac(s, r)

        # 5. MAC and output the key list
        key_list = []
        yt = self.tec_arithmetic.mul(scalar=y, a=t)  # yT
        yr = self.tec_arithmetic.mul(scalar=y, a=r)  # yR
        for i in range(target_num):
            iyt = self.tec_arithmetic.mul(scalar=i, a=yt)  # iyT
            diff = self.tec_arithmetic.sub(a=yr, b=iyt)  # yR - iyT
            key = self._mac_tec_element(diff)
            LOGGER.info("{}-th key generated".format(i))
            LOGGER.info("key before MAC = " + diff.output())
            LOGGER.info("key = {}".format(key))
            key_list.append(ObliviousTransferKey(i, key))

        LOGGER.info("all keys successfully generated")

        return key_list
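
Why only the receiver's chosen index yields a usable key: with S = yG and R = cT + xG (per the step-4 comment), the i-th derived key is

    \mathrm{key}_i = \mathrm{MAC}(yR - iyT) = \mathrm{MAC}\bigl((c - i)\,yT + xyG\bigr),

so the unknown multiple of T vanishes only at i = c, leaving \mathrm{MAC}(xyG), which the receiver can also compute as \mathrm{MAC}(xS). This is the standard 1-out-of-N oblivious-transfer argument, restated from the code's own comments as a reading aid.
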
Example #20
    def fit(self, data_instances, validate_data=None):
        LOGGER.debug("Start data count: {}".format(data_instances.count()))

        self._abnormal_detection(data_instances)
        self.check_abnormal_values(data_instances)
        self.init_schema(data_instances)
        # validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        self.model_weights = self._init_model_variables(data_instances)
        w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
        self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)

        # LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
            re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1
            LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
                re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        total_data_num = data_instances.count()
        LOGGER.debug("Current data count: {}".format(total_data_num))

        model_weights = self.model_weights
        self.prev_round_weights = copy.deepcopy(model_weights)
        degree = 0
        while self.n_iter_ < self.max_iter + 1:
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            if ((self.n_iter_ + 1) % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter:
                weight = self.aggregator.aggregate_then_get(model_weights, degree=degree,
                                                            suffix=self.n_iter_)
                # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format(
                #     model_weights.unboxed / degree,
                #     degree,
                #     weight.unboxed))
                self.model_weights = LogisticRegressionWeights(weight.unboxed, self.fit_intercept)
                if not self.use_encrypt:
                    loss = self._compute_loss(data_instances, self.prev_round_weights)
                    self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_,))
                    LOGGER.info("n_iters: {}, loss: {}".format(self.n_iter_, loss))
                degree = 0
                self.is_converged = self.aggregator.get_converge_status(suffix=(self.n_iter_,))
                LOGGER.info("n_iters: {}, is_converge: {}".format(self.n_iter_, self.is_converged))
                if self.is_converged or self.n_iter_ == self.max_iter:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                degree += n
                LOGGER.debug('before compute_gradient')
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.applyPartitions(f).reduce(fate_operator.reduce_add)
                grad /= n
                if self.use_proximal:  # use additional proximal term
                    model_weights = self.optimizer.update_model(model_weights,
                                                                grad=grad,
                                                                has_applied=False,
                                                                prev_round_weights=self.prev_round_weights)
                else:
                    model_weights = self.optimizer.update_model(model_weights,
                                                                grad=grad,
                                                                has_applied=False)

                if self.use_encrypt and batch_num % self.re_encrypt_batches == 0:
                    LOGGER.debug("Before accept re_encrypted_model, batch_iter_num: {}".format(batch_num))
                    w = self.cipher.re_cipher(w=model_weights.unboxed,
                                              iter_num=self.n_iter_,
                                              batch_iter_num=batch_num)
                    model_weights = LogisticRegressionWeights(w, self.fit_intercept)
                batch_num += 1

            # validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1

        self.set_summary(self.get_model_summary())
        LOGGER.info("Finish Training task, total iters: {}".format(self.n_iter_))
Example #21
 def sync_tree_node_queue(self, dep=-1):
     LOGGER.info("get tree node queue of depth {}".format(dep))
     self.cur_layer_nodes = self.transfer_inst.tree_node_queue.get(
         idx=0, suffix=(dep, ))
Example #22
 def sync_node_positions(self, dep, idx=-1):
     LOGGER.info("send node positions of depth {}".format(dep))
     self.transfer_inst.node_positions.remote(self.inst2node_idx,
                                              role=consts.HOST,
                                              idx=idx,
                                              suffix=(dep,))
Example #23
 def sync_tree(self):
     LOGGER.info("sync tree from guest")
     self.tree_node = self.transfer_inst.tree.get(idx=0)
Example #24
 def sync_federated_best_splitinfo_host(self, federated_best_splitinfo_host, dep=-1, batch=-1, idx=-1):
     LOGGER.info("send federated best splitinfo of depth {}, batch {}".format(dep, batch))
     self.transfer_inst.federated_best_splitinfo_host.remote(federated_best_splitinfo_host,
                                                             role=consts.HOST,
                                                             idx=idx,
                                                             suffix=(dep, batch,))
Example #25
 def sync_predict_data(self, recv_times):
     LOGGER.info(
         "receive predict data from guest, recv times is {}".format(recv_times))
     predict_data = self.transfer_inst.predict_data.get(
         idx=0, suffix=(recv_times, ))
     return predict_data
Example #26
 def sync_final_split_host(self, dep=-1, batch=-1, idx=-1):
     LOGGER.info("get host final splitinfo of depth {}, batch {}".format(dep, batch))
     final_splitinfo_host = self.transfer_inst.final_splitinfo_host.get(idx=idx,
                                                                        suffix=(dep, batch,))
     return final_splitinfo_host if idx == -1 else [final_splitinfo_host]
Example #27
 def get_grad_hess_sum(self, grad_and_hess_table):
     LOGGER.info("calculate the sum of grad and hess")
     grad, hess = grad_and_hess_table.reduce(
         lambda v1, v2: (v1[0] + v2[0], v1[1] + v2[1]))
     return grad, hess
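
The distributed reduce pairwise-sums the (grad, hess) tuples across all partitions; functools.reduce over an in-memory list performs the same aggregation:

    from functools import reduce

    pairs = [(0.5, 1.0), (-0.25, 1.0), (0.25, 1.0)]  # toy (grad, hess) values
    g_sum, h_sum = reduce(lambda a, b: (a[0] + b[0], a[1] + b[1]), pairs)
    assert (g_sum, h_sum) == (0.5, 3.0)
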
Example #28
 def sync_tree(self, idx=-1):
     LOGGER.info("sync tree to host")
     tree_nodes = self.remove_sensitive_info()
     self.transfer_inst.tree.remote(tree_nodes,
                                    role=consts.HOST,
                                    idx=idx)
Example #29
 def load_model(self, model_meta=None, model_param=None):
     LOGGER.info("load tree model")
     self.set_model_meta(model_meta)
     self.set_model_param(model_param)
Example #30
    def fit(self, data_instances=None, validate_data=None):
        """
        Fit OneVsRest model
        Parameters
        ----------
        data_instances: Table of instances
        """
        if self.mode == consts.HOMO:
            raise ValueError(
                "Currently, One vs Rest is not supported for homo algorithm")

        LOGGER.info("mode is {}, role is {}, start to one_vs_rest fit".format(
            self.mode, self.role))

        LOGGER.info("Total classes:{}".format(self.classes))

        self.classifier.callback_one_vs_rest = True
        current_flow_id = self.classifier.flowid
        summary_dict = {}
        for label_index, label in enumerate(self.classes):
            LOGGER.info(
                "Start to train OneVsRest with label_index:{}, label:{}".
                format(label_index, label))
            classifier = copy.deepcopy(self.classifier)
            classifier.need_one_vs_rest = False
            classifier.set_flowid(".".join(
                [current_flow_id, "model_" + str(label_index)]))
            if self.has_label:
                header = data_instances.schema.get("header")
                data_instances_mask_label = self._mask_data_label(
                    data_instances, label=label)
                data_instances_mask_label.schema['header'] = header

                if validate_data is not None:
                    validate_mask_label_data = self._mask_data_label(
                        validate_data, label=label)
                    validate_mask_label_data.schema['header'] = header
                else:
                    validate_mask_label_data = validate_data
                LOGGER.info("finish mask label:{}".format(label))

                LOGGER.info("start classifier fit")
                classifier.fit_binary(data_instances_mask_label,
                                      validate_data=validate_mask_label_data)
            else:
                LOGGER.info("start classifier fit")
                classifier.fit_binary(data_instances,
                                      validate_data=validate_data)
            _summary = classifier.summary()
            _summary['one_vs_rest'] = True
            summary_dict[label] = _summary
            self.models.append(classifier)
            LOGGER.info("Finish model_{} training!".format(label_index))
        self.classifier.set_summary(summary_dict)
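
One-vs-rest trains one binary classifier per class by relabeling: the current label becomes the positive class and everything else negative. A sketch of that relabeling on plain lists (the assumed behaviour of _mask_data_label, which in FATE operates on Instance tables):

    def mask_labels(labels, positive_label):
        # Hypothetical stand-in for _mask_data_label on plain label lists.
        return [1 if y == positive_label else 0 for y in labels]

    assert mask_labels([0, 1, 2, 1], positive_label=1) == [0, 1, 0, 1]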