Пример #1
0
def load_hetero_tree_learner(role,
                             tree_param,
                             model_meta,
                             model_param,
                             flow_id,
                             runtime_idx,
                             host_party_list=None,
                             fast_sbt=False,
                             tree_type=None,
                             target_host_id=None):
    """Rebuild a hetero decision-tree learner from a stored model.

    Args:
        role: ``consts.HOST`` or ``consts.GUEST``; selects the learner class.
        tree_param: Passed to the learner constructor.
        model_meta: Forwarded to the learner's ``load_model``.
        model_param: Forwarded to the learner's ``load_model``.
        flow_id: Forwarded to ``set_flowid``.
        runtime_idx: Forwarded to ``set_runtime_idx``; a fast-SBT host also
            receives it through ``set_self_host_id``.
        host_party_list: Guest only; forwarded to ``set_host_party_idlist``.
        fast_sbt: When true, the ``HeteroFast*`` learner classes are used and
            the tree work mode is configured.
        tree_type: Fast SBT only; first argument of ``set_tree_work_mode``.
        target_host_id: Fast SBT only; second argument of
            ``set_tree_work_mode``.

    Returns:
        The configured learner instance.

    Raises:
        ValueError: If ``role`` is neither the host nor the guest constant.
    """
    is_host = role == consts.HOST
    is_guest = not is_host and role == consts.GUEST
    if not (is_host or is_guest):
        raise ValueError('unknown role: {}'.format(role))

    # Pick the concrete learner class for this role / boosting flavour.
    if is_host:
        learner_cls = (HeteroFastDecisionTreeHost
                       if fast_sbt else HeteroDecisionTreeHost)
    else:
        learner_cls = (HeteroFastDecisionTreeGuest
                       if fast_sbt else HeteroDecisionTreeGuest)

    # Steps shared by both roles.
    learner = learner_cls(tree_param)
    learner.load_model(model_meta, model_param)
    learner.set_flowid(flow_id)
    learner.set_runtime_idx(runtime_idx)

    # Only the guest is told about the host parties.
    if is_guest:
        learner.set_host_party_idlist(host_party_list)

    if fast_sbt:
        learner.set_tree_work_mode(tree_type, target_host_id)
        if is_host:
            learner.set_self_host_id(runtime_idx)

    return learner
Пример #2
0
    def fit_a_booster(self, epoch_idx: int, booster_dim: int):
        """Configure and fit one fast decision tree on the guest side.

        Gradients and hessians are recomputed only when ``epoch_idx`` differs
        from ``self.cur_epoch_idx``; the cached values are reused otherwise.
        After fitting, the tree's feature importance is passed to
        ``self.update_feature_importance``.

        Args:
            epoch_idx: Epoch index; used for the tree plan lookup and the
                flow id.
            booster_dim: Booster dimension; selects the gradient/hessian
                slice and is part of the flow id.

        Returns:
            The fitted ``HeteroFastDecisionTreeGuest``.
        """
        # Resolve the plan for this epoch before doing any work.
        work_mode, host_id = self.get_tree_plan(epoch_idx)
        LOGGER.info('tree work mode is {}'.format(work_mode))
        self.check_host_number(work_mode)

        # g/h are refreshed once per epoch, not once per booster dimension.
        entering_new_epoch = self.cur_epoch_idx != epoch_idx
        if entering_new_epoch:
            self.grad_and_hess = self.compute_grad_and_hess(self.y_hat, self.y)
            self.cur_epoch_idx = epoch_idx

        dim_grad_hess = self.get_grad_and_hess(self.grad_and_hess,
                                               booster_dim)

        booster = HeteroFastDecisionTreeGuest(tree_param=self.tree_param)
        booster.set_input_data(
            self.data_bin,
            self.bin_split_points,
            self.bin_sparse_points,
        )
        booster.set_grad_and_hess(dim_grad_hess)
        booster.set_encrypter(self.encrypter)
        booster.set_encrypted_mode_calculator(self.encrypted_calculator)

        sampled_features = self.sample_valid_features()
        booster.set_valid_features(sampled_features)

        booster_flowid = self.generate_flowid(epoch_idx, booster_dim)
        booster.set_flowid(booster_flowid)

        booster.set_host_party_idlist(
            self.component_properties.host_party_idlist)
        booster.set_runtime_idx(self.component_properties.local_partyid)
        booster.set_tree_work_mode(work_mode, host_id)
        booster.set_layered_depth(self.guest_depth, self.host_depth)

        booster.fit()

        importance = booster.get_feature_importance()
        self.update_feature_importance(importance)
        return booster
Пример #3
0
    def load_booster(self, model_meta, model_param, epoch_idx, booster_idx):
        """Restore one fast guest tree from a stored model.

        Args:
            model_meta: Forwarded to the tree's ``load_model``.
            model_param: Forwarded to the tree's ``load_model``.
            epoch_idx: Epoch index; used for the flow id and the tree plan
                lookup.
            booster_idx: Booster index; part of the flow id.

        Returns:
            The loaded ``HeteroFastDecisionTreeGuest``. It is switched to
            guest-feature-only predict mode when the plan entry for
            ``epoch_idx`` equals ``plan.tree_type_dict['guest_feat_only']``.
        """
        booster = HeteroFastDecisionTreeGuest(self.tree_param)
        booster.load_model(model_meta, model_param)

        booster_flowid = self.generate_flowid(epoch_idx, booster_idx)
        booster.set_flowid(booster_flowid)
        booster.set_runtime_idx(self.component_properties.local_partyid)
        booster.set_host_party_idlist(
            self.component_properties.host_party_idlist)

        work_mode, host_id = self.get_tree_plan(epoch_idx)
        booster.set_tree_work_mode(work_mode, host_id)

        # The plan entry is read directly from self.tree_plan, as before.
        planned_type = self.tree_plan[epoch_idx][0]
        guest_only_type = plan.tree_type_dict['guest_feat_only']
        if planned_type == guest_only_type:
            LOGGER.debug('tree of epoch {} is guest only'.format(epoch_idx))
            booster.use_guest_feat_only_predict_mode()

        return booster
Пример #4
0
def produce_hetero_tree_learner(
        role,
        tree_param: DecisionTreeParam,
        flow_id,
        data_bin,
        bin_split_points,
        bin_sparse_points,
        task_type,
        valid_features,
        host_party_list,
        runtime_idx,
        cipher_compress=True,
        mo_tree=False,
        class_num=1,
        g_h=None,
        encrypter=None,  # guest only
        goss_subsample=False,
        complete_secure=False,
        max_sample_weights=1.0,
        bin_num=None,  # host only
        fast_sbt=False,
        tree_type=None,
        target_host_id=None,  # fast sbt only
        guest_depth=2,
        host_depth=3  # fast sbt only
):
    """Create and initialise a hetero decision-tree learner for training.

    The learner class depends on ``role`` and ``fast_sbt``. For fast SBT the
    work mode and layered depth are set before ``init`` is called; a fast-SBT
    host additionally receives its own id and the host party list.

    Args:
        role: ``consts.GUEST`` or ``consts.HOST``.
        tree_param: Passed to the learner constructor.
        flow_id: Passed to ``init`` as ``flowid``.
        data_bin, bin_split_points, bin_sparse_points: Passed to ``init``
            under the same names.
        task_type: Guest only; passed to ``init``.
        valid_features: Passed to ``init``.
        host_party_list: Passed to the guest's ``init``; for a fast-SBT host
            it goes to ``set_host_party_idlist``.
        runtime_idx: Passed to ``init``; a fast-SBT host also gets it through
            ``set_self_host_id``.
        cipher_compress: Passed to ``init`` as ``cipher_compressing``.
        mo_tree: Passed to ``init``.
        class_num: Guest only; passed to ``init``.
        g_h: Guest only; passed to ``init`` as ``grad_and_hess``.
        encrypter: Guest only; passed to ``init``.
        goss_subsample, complete_secure: Passed to ``init``.
        max_sample_weights: Guest only; passed to ``init`` as
            ``max_sample_weight``.
        bin_num: Host only; passed to ``init``.
        fast_sbt: Selects the ``HeteroFast*`` learner classes.
        tree_type, target_host_id: Fast SBT only; arguments of
            ``set_tree_work_mode``.
        guest_depth, host_depth: Fast SBT only; arguments of
            ``set_layered_depth``.

    Returns:
        The initialised learner.

    Raises:
        ValueError: If ``role`` is neither the guest nor the host constant.
    """
    role_is_guest = role == consts.GUEST
    role_is_host = not role_is_guest and role == consts.HOST
    if not (role_is_guest or role_is_host):
        raise ValueError('unknown role: {}'.format(role))

    if role_is_guest:
        if fast_sbt:
            learner = HeteroFastDecisionTreeGuest(tree_param)
            learner.set_tree_work_mode(tree_type, target_host_id)
            learner.set_layered_depth(guest_depth, host_depth)
        else:
            learner = HeteroDecisionTreeGuest(tree_param)

        # Keyword order mirrors the guest-side init call.
        guest_init_kwargs = dict(
            flowid=flow_id,
            data_bin=data_bin,
            bin_split_points=bin_split_points,
            bin_sparse_points=bin_sparse_points,
            grad_and_hess=g_h,
            encrypter=encrypter,
            task_type=task_type,
            valid_features=valid_features,
            host_party_list=host_party_list,
            runtime_idx=runtime_idx,
            goss_subsample=goss_subsample,
            complete_secure=complete_secure,
            cipher_compressing=cipher_compress,
            max_sample_weight=max_sample_weights,
            mo_tree=mo_tree,
            class_num=class_num,
        )
        learner.init(**guest_init_kwargs)
        return learner

    # Host side.
    if fast_sbt:
        learner = HeteroFastDecisionTreeHost(tree_param)
        learner.set_tree_work_mode(tree_type, target_host_id)
        learner.set_layered_depth(guest_depth, host_depth)
        learner.set_self_host_id(runtime_idx)
        learner.set_host_party_idlist(host_party_list)
    else:
        learner = HeteroDecisionTreeHost(tree_param)

    # Keyword order mirrors the host-side init call.
    host_init_kwargs = dict(
        flowid=flow_id,
        valid_features=valid_features,
        data_bin=data_bin,
        bin_split_points=bin_split_points,
        bin_sparse_points=bin_sparse_points,
        runtime_idx=runtime_idx,
        goss_subsample=goss_subsample,
        complete_secure=complete_secure,
        cipher_compressing=cipher_compress,
        bin_num=bin_num,
        mo_tree=mo_tree,
    )
    learner.init(**host_init_kwargs)
    return learner
    def fit_a_booster(self, epoch_idx: int, booster_dim: int):
        """Initialise and fit one fast decision tree on the guest side.

        Gradients and hessians are recomputed only when ``epoch_idx`` differs
        from ``self.cur_epoch_idx``. The tree is set up through a single
        ``init`` call, then given its work mode and layered depth, and fitted.
        Its feature importance is passed to ``self.update_feature_importance``.

        Args:
            epoch_idx: Epoch index; used for the tree plan lookup and the
                flow id.
            booster_dim: Booster dimension; selects the gradient/hessian
                slice and is part of the flow id.

        Returns:
            The fitted ``HeteroFastDecisionTreeGuest``.
        """
        # Resolve the plan for this epoch before doing any work.
        work_mode, host_id = self.get_tree_plan(epoch_idx)
        LOGGER.info('tree work mode is {}'.format(work_mode))
        self.check_host_number(work_mode)

        # g/h are refreshed once per epoch, not once per booster dimension.
        entering_new_epoch = self.cur_epoch_idx != epoch_idx
        if entering_new_epoch:
            self.grad_and_hess = self.compute_grad_and_hess(
                self.y_hat, self.y, self.data_inst)
            self.cur_epoch_idx = epoch_idx

        dim_grad_hess = self.get_grad_and_hess(self.grad_and_hess,
                                               booster_dim)

        booster = HeteroFastDecisionTreeGuest(tree_param=self.tree_param)

        booster_flowid = self.generate_flowid(epoch_idx, booster_dim)
        sampled_features = self.sample_valid_features()
        # Complete-secure mode is requested only while cur_epoch_idx is 0
        # and the component has complete_secure set.
        use_complete_secure = bool(
            self.cur_epoch_idx == 0 and self.complete_secure)
        compress_cipher = self.round_decimal is not None

        booster.init(
            flowid=booster_flowid,
            data_bin=self.data_bin,
            bin_split_points=self.bin_split_points,
            bin_sparse_points=self.bin_sparse_points,
            grad_and_hess=dim_grad_hess,
            encrypter=self.encrypter,
            encrypted_mode_calculator=self.encrypted_calculator,
            valid_features=sampled_features,
            host_party_list=self.component_properties.host_party_idlist,
            runtime_idx=self.component_properties.local_partyid,
            goss_subsample=self.enable_goss,
            top_rate=self.top_rate,
            other_rate=self.other_rate,
            complete_secure=use_complete_secure,
            cipher_compressing=compress_cipher,
            round_decimal=self.round_decimal,
            encrypt_key_length=self.encrypt_param.key_length,
            max_sample_weight=self.max_sample_weight,
            new_ver=self.new_ver,
        )
        booster.set_tree_work_mode(work_mode, host_id)
        booster.set_layered_depth(self.guest_depth, self.host_depth)

        booster.fit()

        importance = booster.get_feature_importance()
        self.update_feature_importance(importance)
        return booster