Example #1
0
        def _merge_heap(constraint=None, aim_var=0):
            """Merge bucket pairs taken from ``min_heap`` until the loop condition fails.

            Each iteration pops a node from ``min_heap``. If both of the
            node's buckets are still present in ``bucket_dict``, they are
            replaced by a single new bucket built by ``_init_new_bucket``.
            New bucket ids start one past the current largest key of
            ``bucket_dict`` and increase by one per merge.

            Parameters
            ----------
            constraint : optional
                Passed unchanged to ``_aim_vars_decrease`` and
                ``_add_node_from_new_bucket``.
            aim_var : number
                Merging continues while this is positive. It is recomputed by
                ``_aim_vars_decrease`` after every successful merge.

            Returns
            -------
            tuple
                ``(min_heap, aim_var)`` once ``aim_var <= 0`` or the heap is empty.
            """
            fresh_id = max(bucket_dict) + 1
            while aim_var > 0 and not min_heap.is_empty:
                node = min_heap.pop()
                lhs, rhs = node.left_bucket, node.right_bucket

                # A heap node may point at a bucket that an earlier merge
                # already consumed; such stale nodes are skipped.
                if not (lhs.idx in bucket_dict and rhs.idx in bucket_dict):
                    continue

                merged = bucket_info.Bucket(
                    idx=fresh_id,
                    adjustment_factor=optimal_param.adjustment_factor)
                merged = _init_new_bucket(merged, node)
                bucket_dict[fresh_id] = merged

                # Drop both source buckets first, then clean their heap
                # entries (same order as before: left, right).
                for consumed in (lhs, rhs):
                    del bucket_dict[consumed.idx]
                for consumed in (lhs, rhs):
                    min_heap.remove_empty_node(consumed.idx)

                aim_var = _aim_vars_decrease(constraint, merged, lhs, rhs, aim_var)
                # Presumably pushes heap nodes linking the new bucket to its
                # neighbours — verify against _add_node_from_new_bucket.
                _add_node_from_new_bucket(merged, constraint)
                fresh_id += 1
            return min_heap, aim_var
Example #2
0
    def init_bucket(self, data_instances):
        """
        Build the initial per-feature bucket lists and aggregate them over the data.

        Split points are fitted with ``QuantileBinningTool`` when
        ``optimal_param.init_bucket_method == consts.QUANTILE`` and with
        ``BucketBinning`` otherwise, using ``optimal_param.init_bin_nums`` bins.
        For every feature, each split point becomes the right bound of one
        ``bucket_info.Bucket``. That bucket's left bound is the previous split
        point, or ``-inf`` for the first bucket. The first bucket has no left
        neighbour and the last has no right neighbour.

        ``convert_data_to_bucket`` (bound to these template buckets) is mapped
        over the partitions of ``data_instances``. The per-partition outputs
        are combined with ``merge_bucket_list``.

        Parameters
        ----------
        data_instances : table
            Input data; must support ``mapReducePartitions``.

        Returns
        -------
        The table returned by ``data_instances.mapReducePartitions``.
        """
        header = data_overview.get_header(data_instances)
        self._default_setting(header)

        init_bucket_param = copy.deepcopy(self.params)
        init_bucket_param.bin_num = self.optimal_param.init_bin_nums
        if self.optimal_param.init_bucket_method == consts.QUANTILE:
            init_binning_obj = QuantileBinningTool(param_obj=init_bucket_param,
                                                   allow_duplicate=False)
        else:
            init_binning_obj = BucketBinning(params=init_bucket_param)
        init_binning_obj.set_bin_inner_param(self.bin_inner_param)
        init_split_points = init_binning_obj.fit_split_points(data_instances)
        is_sparse = data_overview.is_sparse_data(data_instances)

        bucket_dict = {}
        for col_name, sps in init_split_points.items():
            bucket_list = []
            # Left bound of each bucket is the previous split point.
            left_bound = -math.inf
            for idx, sp in enumerate(sps):
                bucket = bucket_info.Bucket(idx,
                                            self.adjustment_factor,
                                            right_bound=sp)
                bucket.left_bound = left_bound
                if idx == 0:
                    bucket.set_left_neighbor(None)
                bucket.event_total = self.event_total
                bucket.non_event_total = self.non_event_total
                bucket_list.append(bucket)
                left_bound = sp
            # NOTE(review): raises IndexError if a feature has no split points.
            bucket_list[-1].set_right_neighbor(None)
            bucket_dict[col_name] = bucket_list

        # The template buckets are deep-copied before binding, presumably so the
        # converter cannot mutate ``bucket_dict`` — confirm.
        convert_func = functools.partial(
            self.convert_data_to_bucket,
            split_points=init_split_points,
            headers=self.header,
            bucket_dict=copy.deepcopy(bucket_dict),
            is_sparse=is_sparse,
            get_bin_num_func=self.get_bin_num)
        return data_instances.mapReducePartitions(
            convert_func, self.merge_bucket_list)
    def bin_sum_to_bucket_list(self, bin_sum, partitions):
        """
        Convert a bin-sum result (typically received from host) into a table of bucket lists.

        Parameters
        ----------
        bin_sum : dict
            Maps feature name to a per-bin list of counts; index 0 of each
            entry is the event count and index 1 the non-event count, e.g.
            {'x1': [[event_count, non_event_count], [event_count, non_event_count] ... ],
             'x2': [[event_count, non_event_count], [event_count, non_event_count] ... ],
             ...
            }

        partitions : int
            Number of partitions for the created table.

        Returns
        -------
        A table (created with ``session.parallelize``) whose keys are feature
        names and whose values are lists of ``bucket_info.Bucket``, in the
        same order as ``bin_sum``.
        """
        result = []
        for col_name, bin_res_list in bin_sum.items():
            last_idx = len(bin_res_list) - 1
            bucket_list = []
            for b_idx, counts in enumerate(bin_res_list):
                bucket = bucket_info.Bucket(b_idx, self.adjustment_factor)
                # The outermost buckets have no neighbour on their open side.
                if b_idx == 0:
                    bucket.set_left_neighbor(None)
                if b_idx == last_idx:
                    bucket.set_right_neighbor(None)
                bucket.event_count = counts[0]
                bucket.non_event_count = counts[1]
                # Bounds are bin positions, not real split values (presumably
                # because split points are not available here — confirm).
                bucket.left_bound = b_idx - 1
                bucket.right_bound = b_idx
                bucket.event_total = self.event_total
                bucket.non_event_total = self.non_event_total
                bucket_list.append(bucket)
            result.append((col_name, bucket_list))

        return session.parallelize(result,
                                   include_key=True,
                                   partition=partitions)
Example #4
0
    def init_bucket(self, data_instances):
        """
        Build the initial per-feature bucket lists, aggregate them over the data, and return them as a table.

        Split points are fitted with ``QuantileBinningTool`` when
        ``optimal_param.init_bucket_method == consts.QUANTILE`` and with
        ``BucketBinning`` otherwise, using ``optimal_param.init_bin_nums`` bins.
        For every feature, each split point becomes the right bound of one
        ``bucket_info.Bucket``. That bucket's left bound is the previous split
        point, or ``-inf`` for the first bucket.

        The aggregation path depends on ``get_eggroll_version()``:

        * versions starting with "2.0" use ``mapPartitions`` with
          ``convert_data_to_bucket_old`` followed by
          ``federatedml.util.reduce_by_key.reduce``;
        * versions starting with "2.2" use ``mapReducePartitions`` with
          ``convert_data_to_bucket`` and collect the result into a dict.

        The merged mapping is then re-parallelized with
        ``data_instances.partitions``.

        Parameters
        ----------
        data_instances : table
            Input data.

        Returns
        -------
        A table (created with ``session.parallelize``) keyed by feature name
        whose values are the merged bucket lists.

        Raises
        ------
        RuntimeError
            If the eggroll version starts with neither "2.0" nor "2.2".
        """
        header = data_overview.get_header(data_instances)
        self._default_setting(header)

        init_bucket_param = copy.deepcopy(self.params)
        init_bucket_param.bin_num = self.optimal_param.init_bin_nums
        if self.optimal_param.init_bucket_method == consts.QUANTILE:
            init_binning_obj = QuantileBinningTool(param_obj=init_bucket_param, allow_duplicate=False)
        else:
            init_binning_obj = BucketBinning(params=init_bucket_param)
        init_binning_obj.set_bin_inner_param(self.bin_inner_param)
        init_split_points = init_binning_obj.fit_split_points(data_instances)
        is_sparse = data_overview.is_sparse_data(data_instances)

        bucket_dict = {}
        for col_name, sps in init_split_points.items():
            bucket_list = []
            # Left bound of each bucket is the previous split point.
            left_bound = -math.inf
            for idx, sp in enumerate(sps):
                bucket = bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
                bucket.left_bound = left_bound
                if idx == 0:
                    bucket.set_left_neighbor(None)
                bucket.event_total = self.event_total
                bucket.non_event_total = self.non_event_total
                bucket_list.append(bucket)
                left_bound = sp
            # NOTE(review): raises IndexError if a feature has no split points.
            bucket_list[-1].set_right_neighbor(None)
            bucket_dict[col_name] = bucket_list
            LOGGER.debug(f"col_name: {col_name}, length of sps: {len(sps)}, "
                         f"length of list: {len(bucket_list)}")

        def _bind_converter(convert):
            # Both aggregation paths bind the same keyword arguments. The
            # template buckets are deep-copied, presumably so the converter
            # cannot mutate ``bucket_dict`` — confirm.
            return functools.partial(convert,
                                     split_points=init_split_points,
                                     headers=self.header,
                                     bucket_dict=copy.deepcopy(bucket_dict),
                                     is_sparse=is_sparse,
                                     get_bin_num_func=self.get_bin_num)

        from fate_arch.common.versions import get_eggroll_version
        version = get_eggroll_version()
        if version.startswith('2.0'):
            summary_table = data_instances.mapPartitions(
                _bind_converter(self.convert_data_to_bucket_old),
                use_previous_behavior=False)
            from federatedml.util.reduce_by_key import reduce as reduce_by_key
            # Assumes reduce_by_key returns a mapping (``.items()`` is used below).
            merged_buckets = reduce_by_key(summary_table, self.merge_bucket_list,
                                           key_func=lambda key: key[1])
        elif version.startswith('2.2'):
            merged_table = data_instances.mapReducePartitions(
                _bind_converter(self.convert_data_to_bucket),
                self.merge_bucket_list)
            merged_buckets = dict(merged_table.collect())
        else:
            raise RuntimeError(f"Cannot recognized eggroll version: {version}")

        for feature, buckets in merged_buckets.items():
            LOGGER.debug(f"[feature] {feature}, length of list: {len(buckets)}")

        LOGGER.debug("bucket_table: {}, length: {}".format(type(merged_buckets), len(merged_buckets)))
        bucket_pairs = list(merged_buckets.items())
        LOGGER.debug("bucket_table: {}, length: {}".format(type(bucket_pairs), len(bucket_pairs)))

        return session.parallelize(bucket_pairs, include_key=True, partition=data_instances.partitions)