def _merge_heap(constraint=None, aim_var=0):
    next_id = max(bucket_dict.keys()) + 1
    while aim_var > 0 and not min_heap.is_empty:
        # Pop the candidate merge with the smallest loss
        min_node = min_heap.pop()
        left_bucket = min_node.left_bucket
        right_bucket = min_node.right_bucket

        # Some buckets may be already merged; skip stale heap entries
        if left_bucket.idx not in bucket_dict or right_bucket.idx not in bucket_dict:
            continue

        new_bucket = bucket_info.Bucket(idx=next_id,
                                        adjustment_factor=optimal_param.adjustment_factor)
        new_bucket = _init_new_bucket(new_bucket, min_node)

        # Replace the two source buckets with the merged bucket and drop their heap nodes
        bucket_dict[next_id] = new_bucket
        del bucket_dict[left_bucket.idx]
        del bucket_dict[right_bucket.idx]
        min_heap.remove_empty_node(left_bucket.idx)
        min_heap.remove_empty_node(right_bucket.idx)

        aim_var = _aim_vars_decrease(constraint, new_bucket, left_bucket, right_bucket, aim_var)
        _add_node_from_new_bucket(new_bucket, constraint)
        next_id += 1
    return min_heap, aim_var
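# A minimal, self-contained sketch (not FATE code) of the greedy merge pattern used in
# _merge_heap above: candidate merges live in a min-heap keyed by loss, and entries whose
# buckets were already consumed by an earlier merge are skipped lazily instead of being
# removed eagerly. The name `toy_merge` and the (loss, left_id, right_id) tuples are
# hypothetical; the real code also pushes new candidates for the freshly created bucket
# (via _add_node_from_new_bucket), which is omitted here for brevity.
import heapq


def toy_merge(losses, target_bins):
    """Greedily merge adjacent buckets with the smallest loss until `target_bins`
    remain or no valid candidates are left."""
    buckets = {i: [i] for i in range(len(losses) + 1)}          # bucket id -> member ids
    heap = [(loss, i, i + 1) for i, loss in enumerate(losses)]  # (merge loss, left id, right id)
    heapq.heapify(heap)
    next_id = len(buckets)
    while len(buckets) > target_bins and heap:
        loss, left, right = heapq.heappop(heap)
        if left not in buckets or right not in buckets:
            continue  # stale entry: one side was already merged away
        buckets[next_id] = buckets.pop(left) + buckets.pop(right)
        next_id += 1
    return buckets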
def init_bucket(self, data_instances):
    header = data_overview.get_header(data_instances)
    self._default_setting(header)

    init_bucket_param = copy.deepcopy(self.params)
    init_bucket_param.bin_num = self.optimal_param.init_bin_nums

    if self.optimal_param.init_bucket_method == consts.QUANTILE:
        init_binning_obj = QuantileBinningTool(param_obj=init_bucket_param, allow_duplicate=False)
    else:
        init_binning_obj = BucketBinning(params=init_bucket_param)
    init_binning_obj.set_bin_inner_param(self.bin_inner_param)
    init_split_points = init_binning_obj.fit_split_points(data_instances)
    is_sparse = data_overview.is_sparse_data(data_instances)

    bucket_dict = dict()
    for col_name, sps in init_split_points.items():
        bucket_list = []
        for idx, sp in enumerate(sps):
            bucket = bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
            if idx == 0:
                bucket.left_bound = -math.inf
                bucket.set_left_neighbor(None)
            else:
                bucket.left_bound = sps[idx - 1]
            bucket.event_total = self.event_total
            bucket.non_event_total = self.non_event_total
            bucket_list.append(bucket)
        bucket_list[-1].set_right_neighbor(None)
        bucket_dict[col_name] = bucket_list
        # LOGGER.debug(f"col_name: {col_name}, length of sps: {len(sps)}, "
        #              f"length of list: {len(bucket_list)}")

    convert_func = functools.partial(self.convert_data_to_bucket,
                                     split_points=init_split_points,
                                     headers=self.header,
                                     bucket_dict=copy.deepcopy(bucket_dict),
                                     is_sparse=is_sparse,
                                     get_bin_num_func=self.get_bin_num)
    bucket_table = data_instances.mapReducePartitions(convert_func, self.merge_bucket_list)
    # bucket_table = dict(bucket_table.collect())
    # for k, v in bucket_table.items():
    #     LOGGER.debug(f"[feature] {k}, length of list: {len(v)}")
    # LOGGER.debug("bucket_table: {}, length: {}".format(type(bucket_table), len(bucket_table)))
    # bucket_table = [(k, v) for k, v in bucket_table.items()]
    # LOGGER.debug("bucket_table: {}, length: {}".format(type(bucket_table), len(bucket_table)))
    # bucket_table = session.parallelize(bucket_table, include_key=True,
    #                                    partition=data_instances.partitions)
    return bucket_table
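# A hedged, simplified stand-in (not the actual convert_data_to_bucket /
# merge_bucket_list implementations) for the map-reduce step above: the map side assigns
# each sample to a bucket by bisecting the sorted split points and accumulates
# event / non-event counts, and the reduce side merges two per-feature count lists
# element-wise. `assign_counts` and `merge_counts` are hypothetical helpers that track
# only counts rather than full Bucket objects.
import bisect


def assign_counts(samples, split_points):
    """samples: iterable of (value, label); returns [[event, non_event], ...] per bucket."""
    counts = [[0, 0] for _ in range(len(split_points))]
    for value, label in samples:
        # Values above the last split point fall into the last bucket
        bin_idx = min(bisect.bisect_left(split_points, value), len(split_points) - 1)
        counts[bin_idx][0 if label == 1 else 1] += 1
    return counts


def merge_counts(left, right):
    """Reducer: element-wise sum of two per-bucket count lists for the same feature."""
    return [[a[0] + b[0], a[1] + b[1]] for a, b in zip(left, right)]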
def bin_sum_to_bucket_list(self, bin_sum, partitions):
    """
    Convert a bin-sum result, typically received from the host, into bucket lists.

    Parameters
    ----------
    bin_sum : dict
        {'x1': [[event_count, non_event_count], [event_count, non_event_count] ... ],
         'x2': [[event_count, non_event_count], [event_count, non_event_count] ... ],
         ...
        }

    partitions : int
        Number of partitions for the created table.

    Returns
    -------
    A DTable whose keys are feature names and values are bucket lists.
    """
    bucket_dict = dict()
    for col_name, bin_res_list in bin_sum.items():
        # bucket_list = [bucket_info.Bucket(idx, self.adjustment_factor) for idx in range(len(bin_res_list))]
        # bucket_list[0].set_left_neighbor(None)
        # bucket_list[-1].set_right_neighbor(None)
        # for b_idx, bucket in enumerate(bucket_list):
        bucket_list = []
        for b_idx in range(len(bin_res_list)):
            bucket = bucket_info.Bucket(b_idx, self.adjustment_factor)
            if b_idx == 0:
                bucket.set_left_neighbor(None)
            if b_idx == len(bin_res_list) - 1:
                bucket.set_right_neighbor(None)
            bucket.event_count = bin_res_list[b_idx][0]
            bucket.non_event_count = bin_res_list[b_idx][1]
            # The host exposes only ordinal bin indices, so bounds are the indices themselves
            bucket.left_bound = b_idx - 1
            bucket.right_bound = b_idx
            bucket.event_total = self.event_total
            bucket.non_event_total = self.non_event_total
            bucket_list.append(bucket)
        bucket_dict[col_name] = bucket_list

    result = []
    for col_name, bucket_list in bucket_dict.items():
        result.append((col_name, bucket_list))
    result_table = session.parallelize(result, include_key=True, partition=partitions)
    return result_table
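# A minimal illustration of the conversion performed above, using a hypothetical
# ToyBucket dataclass rather than bucket_info.Bucket: host-side bin sums of the form
# {'x1': [[event, non_event], ...], ...} become per-feature bucket lists whose bounds
# are simply the ordinal bin indices, since the host never reveals real split values.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyBucket:
    idx: int
    event_count: int
    non_event_count: int
    left_bound: float
    right_bound: float


def toy_bin_sum_to_buckets(bin_sum: Dict[str, List[List[int]]]) -> Dict[str, List[ToyBucket]]:
    result = {}
    for col_name, bin_res_list in bin_sum.items():
        result[col_name] = [
            ToyBucket(idx=i, event_count=event, non_event_count=non_event,
                      left_bound=i - 1, right_bound=i)
            for i, (event, non_event) in enumerate(bin_res_list)
        ]
    return result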
def init_bucket(self, data_instances):
    header = data_overview.get_header(data_instances)
    self._default_setting(header)

    init_bucket_param = copy.deepcopy(self.params)
    init_bucket_param.bin_num = self.optimal_param.init_bin_nums

    if self.optimal_param.init_bucket_method == consts.QUANTILE:
        init_binning_obj = QuantileBinningTool(param_obj=init_bucket_param, allow_duplicate=False)
    else:
        init_binning_obj = BucketBinning(params=init_bucket_param)
    init_binning_obj.set_bin_inner_param(self.bin_inner_param)
    init_split_points = init_binning_obj.fit_split_points(data_instances)
    is_sparse = data_overview.is_sparse_data(data_instances)

    bucket_dict = dict()
    for col_name, sps in init_split_points.items():
        # bucket_list = [bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
        #                for idx, sp in enumerate(sps)]
        bucket_list = []
        for idx, sp in enumerate(sps):
            bucket = bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
            if idx == 0:
                bucket.left_bound = -math.inf
                bucket.set_left_neighbor(None)
            else:
                bucket.left_bound = sps[idx - 1]
            bucket.event_total = self.event_total
            bucket.non_event_total = self.non_event_total
            bucket_list.append(bucket)
        bucket_list[-1].set_right_neighbor(None)
        bucket_dict[col_name] = bucket_list
        LOGGER.debug(f"col_name: {col_name}, length of sps: {len(sps)}, "
                     f"length of list: {len(bucket_list)}")

    # bucket_table = data_instances.mapPartitions2(convert_func)
    # bucket_table = bucket_table.reduce(self.merge_bucket_list, key_func=lambda key: key[1])

    from fate_arch.common.versions import get_eggroll_version
    version = get_eggroll_version()
    if version.startswith('2.0'):
        convert_func = functools.partial(self.convert_data_to_bucket_old,
                                         split_points=init_split_points,
                                         headers=self.header,
                                         bucket_dict=copy.deepcopy(bucket_dict),
                                         is_sparse=is_sparse,
                                         get_bin_num_func=self.get_bin_num)
        summary_dict = data_instances.mapPartitions(convert_func, use_previous_behavior=False)
        # summary_dict = summary_dict.reduce(self.copy_merge, key_func=lambda key: key[1])
        from federatedml.util.reduce_by_key import reduce
        bucket_table = reduce(summary_dict, self.merge_bucket_list, key_func=lambda key: key[1])
    elif version.startswith('2.2'):
        convert_func = functools.partial(self.convert_data_to_bucket,
                                         split_points=init_split_points,
                                         headers=self.header,
                                         bucket_dict=copy.deepcopy(bucket_dict),
                                         is_sparse=is_sparse,
                                         get_bin_num_func=self.get_bin_num)
        bucket_table = data_instances.mapReducePartitions(convert_func, self.merge_bucket_list)
        bucket_table = dict(bucket_table.collect())
    else:
        raise RuntimeError(f"Cannot recognize eggroll version: {version}")

    for k, v in bucket_table.items():
        LOGGER.debug(f"[feature] {k}, length of list: {len(v)}")
    LOGGER.debug("bucket_table: {}, length: {}".format(type(bucket_table), len(bucket_table)))
    bucket_table = [(k, v) for k, v in bucket_table.items()]
    LOGGER.debug("bucket_table: {}, length: {}".format(type(bucket_table), len(bucket_table)))
    bucket_table = session.parallelize(bucket_table, include_key=True,
                                       partition=data_instances.partitions)
    return bucket_table
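# The eggroll-2.0 branch above cannot fuse map and reduce, so per-partition summaries are
# produced first and then folded together by key. Below is a hypothetical pure-Python
# stand-in for that reduce-by-key step (the real helper lives in
# federatedml.util.reduce_by_key and additionally accepts a key_func); it only sketches
# the folding idea.
def toy_reduce_by_key(pairs, reduce_func):
    """pairs: iterable of (key, value); fold values sharing a key with reduce_func."""
    merged = {}
    for key, value in pairs:
        merged[key] = value if key not in merged else reduce_func(merged[key], value)
    return merged


# Example:
# toy_reduce_by_key([("x1", [1, 2]), ("x1", [3, 4])],
#                   lambda a, b: [a[0] + b[0], a[1] + b[1]])
# -> {"x1": [4, 6]}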