Example #1
    def fit(self):
        """
        Fit the homo decision tree on the distributed backend, layer by layer.
        """
        LOGGER.info('begin to fit homo decision tree, epoch {}, tree idx {}, '
                    'running on distributed backend'.format(self.epoch_idx, self.tree_idx))

        self.init_root_node_and_gh_sum()
        LOGGER.debug('assign samples to root node')
        self.inst2node_idx = self.assign_instance_to_root_node(self.data_bin, 0)

        tree_height = self.max_depth + 1  # non-leaf node height + 1 layer leaf

        for dep in range(tree_height):

            if dep + 1 == tree_height:

                for node in self.cur_layer_node:
                    node.is_leaf = True
                    self.tree_node.append(node)

                rest_sample_leaf_pos = self.inst2node_idx.mapValues(lambda x: x[1])
                if self.sample_leaf_pos is None:
                    self.sample_leaf_pos = rest_sample_leaf_pos
                else:
                    self.sample_leaf_pos = self.sample_leaf_pos.union(rest_sample_leaf_pos)
                # stop fitting
                break

            LOGGER.debug('start to fit layer {}'.format(dep))

            table_with_assignment = self.update_instances_node_positions()

            # send current layer node number:
            self.sync_cur_layer_node_num(len(self.cur_layer_node), suffix=(dep, self.epoch_idx, self.tree_idx))

            split_info, agg_histograms = [], []
            for batch_id, i in enumerate(range(0, len(self.cur_layer_node), self.max_split_nodes)):
                cur_to_split = self.cur_layer_node[i:i+self.max_split_nodes]

                node_map = self.get_node_map(nodes=cur_to_split)
                LOGGER.debug('node map is {}'.format(node_map))
                LOGGER.debug('computing histogram for batch {} at depth {}'.format(batch_id, dep))
                local_histogram = self.get_left_node_local_histogram(
                    cur_nodes=cur_to_split,
                    tree=self.tree_node,
                    g_h=self.g_h,
                    table_with_assign=table_with_assignment,
                    split_points=self.bin_split_points,
                    sparse_point=self.bin_sparse_points,
                    valid_feature=self.valid_features
                )

                LOGGER.debug('federated best-split finding for batch {} at layer {}'.format(batch_id, dep))
                self.sync_local_node_histogram(local_histogram, suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))

                agg_histograms += local_histogram

            split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
            LOGGER.debug('got best splits from arbiter')

            new_layer_node = self.update_tree(self.cur_layer_node, split_info)
            self.cur_layer_node = new_layer_node

            self.inst2node_idx, leaf_val = self.assign_instances_to_new_node(table_with_assignment, self.tree_node)
            # record leaf val
            if self.sample_leaf_pos is None:
                self.sample_leaf_pos = leaf_val
            else:
                self.sample_leaf_pos = self.sample_leaf_pos.union(leaf_val)

            LOGGER.debug('assigning instance to new nodes done')

        self.convert_bin_to_real()
        self.sample_weights_post_process()
        LOGGER.debug('fitting tree done')
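The layer loop above slices the current layer's nodes into batches of at most max_split_nodes before computing and syncing histograms. A minimal standalone sketch of that batching pattern (plain Python, not the FATE API; the node list and batch size are illustrative):

def iter_node_batches(cur_layer_node, max_split_nodes):
    # Yield (batch_id, node_batch) pairs, mirroring the slicing in fit() above.
    for batch_id, i in enumerate(range(0, len(cur_layer_node), max_split_nodes)):
        yield batch_id, cur_layer_node[i:i + max_split_nodes]

# 7 nodes with max_split_nodes=3 -> batch sizes [3, 3, 1]
batches = list(iter_node_batches(list(range(7)), max_split_nodes=3))
assert [len(nodes) for _, nodes in batches] == [3, 3, 1]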
Example #2
    def merge_optimal_binning(bucket_list, optimal_param: OptimalBinningParam, sample_count):
        t0 = time.time()
        max_item_num = math.floor(optimal_param.max_bin_pct * sample_count)
        min_item_num = math.ceil(optimal_param.min_bin_pct * sample_count)
        bucket_dict = {idx: bucket for idx, bucket in enumerate(bucket_list)}
        final_max_bin = optimal_param.max_bin

        LOGGER.debug("Enter merge_optimal_binning, sample_count: {}, max_item_num: {}, min_item_num: {}, "
                     "final_max_bin: {}".format(sample_count, max_item_num, min_item_num, final_max_bin))
        min_heap = heap.MinHeap()

        def _add_heap_nodes(constraint=None):
            LOGGER.debug("Add heap nodes, constraint: {}, dict_length: {}".format(constraint, len(bucket_dict)))
            this_non_mixture_num = 0
            this_small_size_num = 0
            # Make bucket satisfy mixture condition

            # for i in bucket_dict.keys():
            for i in range(len(bucket_dict)):
                left_bucket = bucket_dict[i]
                right_bucket = bucket_dict.get(left_bucket.right_neighbor_idx)
                if left_bucket.right_neighbor_idx == i:
                    raise RuntimeError("left_bucket's right neighbor == itself")
                if not left_bucket.is_mixed:
                    this_non_mixture_num += 1

                if left_bucket.total_count < min_item_num:
                    this_small_size_num += 1

                if right_bucket is None:
                    continue
                # Violate maximum items constraint
                if left_bucket.total_count + right_bucket.total_count > max_item_num:
                    continue

                if constraint == 'mixture':
                    if left_bucket.is_mixed or right_bucket.is_mixed:
                        continue
                elif constraint == 'single_mixture':
                    if left_bucket.is_mixed and right_bucket.is_mixed:
                        continue

                elif constraint == 'small_size':
                    if left_bucket.total_count >= min_item_num or right_bucket.total_count >= min_item_num:
                        continue
                elif constraint == 'single_small_size':
                    if left_bucket.total_count >= min_item_num and right_bucket.total_count >= min_item_num:
                        continue
                heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket, right_bucket=right_bucket)
                min_heap.insert(heap_node)
            return min_heap, this_non_mixture_num, this_small_size_num

        def _update_bucket_info(b_dict):
            """
            update bucket information
            """
            order_dict = dict()
            for bucket_idx, item in b_dict.items():
                order_dict[bucket_idx] = item.left_bound

            sorted_order_dict = sorted(order_dict.items(), key=operator.itemgetter(1))

            start_idx = 0
            for item in sorted_order_dict:
                bucket_idx = item[0]
                if start_idx == bucket_idx:
                    start_idx += 1
                    continue

                b_dict[start_idx] = b_dict[bucket_idx]
                b_dict[start_idx].idx = start_idx
                start_idx += 1
                del b_dict[bucket_idx]

            bucket_num = len(b_dict)
            for i in range(bucket_num):
                if i == 0:
                    b_dict[i].set_left_neighbor(None)
                    b_dict[i].set_right_neighbor(i + 1)
                else:
                    b_dict[i].set_left_neighbor(i - 1)
                    b_dict[i].set_right_neighbor(i + 1)
            b_dict[bucket_num - 1].set_right_neighbor(None)
            # for b_dict_idx, bucket in b_dict.items():
            #     LOGGER.debug("After _update_bucket_info, b_dict_idx: {}, b_idx: {}".format(b_dict_idx, bucket.idx))
            return b_dict

        def _merge_heap(constraint=None, aim_var=0):
            next_id = max(bucket_dict.keys()) + 1
            while aim_var > 0 and not min_heap.is_empty:
                min_node = min_heap.pop()
                left_bucket = min_node.left_bucket
                right_bucket = min_node.right_bucket

                # Some buckets may be already merged
                if left_bucket.idx not in bucket_dict or right_bucket.idx not in bucket_dict:
                    continue
                new_bucket = bucket_info.Bucket(idx=next_id, adjustment_factor=optimal_param.adjustment_factor)
                new_bucket = _init_new_bucket(new_bucket, min_node)
                bucket_dict[next_id] = new_bucket
                del bucket_dict[left_bucket.idx]
                del bucket_dict[right_bucket.idx]
                min_heap.remove_empty_node(left_bucket.idx)
                min_heap.remove_empty_node(right_bucket.idx)

                aim_var = _aim_vars_decrease(constraint, new_bucket, left_bucket, right_bucket, aim_var)
                _add_node_from_new_bucket(new_bucket, constraint)
                next_id += 1
            return min_heap, aim_var

        def _add_node_from_new_bucket(new_bucket: bucket_info.Bucket, constraint):
            left_bucket = bucket_dict.get(new_bucket.left_neighbor_idx)
            right_bucket = bucket_dict.get(new_bucket.right_neighbor_idx)
            if constraint == 'mixture':
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if not left_bucket.is_mixed and not new_bucket.is_mixed:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                           right_bucket=new_bucket)
                        min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if not right_bucket.is_mixed and not new_bucket.is_mixed:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                           right_bucket=right_bucket)
                        min_heap.insert(heap_node)

            elif constraint == 'single_mixture':
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if not (left_bucket.is_mixed and new_bucket.is_mixed):
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                           right_bucket=new_bucket)
                        min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if not (right_bucket.is_mixed and new_bucket.is_mixed):
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                           right_bucket=right_bucket)
                        min_heap.insert(heap_node)

            elif constraint == 'small_size':
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if left_bucket.total_count < min_item_num and new_bucket.total_count < min_item_num:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                           right_bucket=new_bucket)
                        min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if right_bucket.total_count < min_item_num and new_bucket.total_count < min_item_num:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                           right_bucket=right_bucket)
                        min_heap.insert(heap_node)

            elif constraint == 'single_small_size':
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if left_bucket.total_count < min_item_num or new_bucket.total_count < min_item_num:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                           right_bucket=new_bucket)
                        min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if right_bucket.total_count < min_item_num or new_bucket.total_count < min_item_num:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                           right_bucket=right_bucket)
                        min_heap.insert(heap_node)
            else:
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                       right_bucket=new_bucket)
                    min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                       right_bucket=right_bucket)
                    min_heap.insert(heap_node)

        def _init_new_bucket(new_bucket: bucket_info.Bucket, min_node: heap.HeapNode):
            new_bucket.left_bound = min_node.left_bucket.left_bound
            new_bucket.right_bound = min_node.right_bucket.right_bound
            new_bucket.left_neighbor_idx = min_node.left_bucket.left_neighbor_idx
            new_bucket.right_neighbor_idx = min_node.right_bucket.right_neighbor_idx
            new_bucket.event_count = min_node.left_bucket.event_count + min_node.right_bucket.event_count
            new_bucket.non_event_count = min_node.left_bucket.non_event_count + min_node.right_bucket.non_event_count
            new_bucket.event_total = min_node.left_bucket.event_total
            new_bucket.non_event_total = min_node.left_bucket.non_event_total

            left_neighbor_bucket = bucket_dict.get(new_bucket.left_neighbor_idx)
            if left_neighbor_bucket is not None:
                left_neighbor_bucket.right_neighbor_idx = new_bucket.idx

            right_neighbor_bucket = bucket_dict.get(new_bucket.right_neighbor_idx)
            if right_neighbor_bucket is not None:
                right_neighbor_bucket.left_neighbor_idx = new_bucket.idx
            return new_bucket

        def _aim_vars_decrease(constraint, new_bucket: bucket_info.Bucket, left_bucket, right_bucket, aim_var):
            if constraint in ['mixture', 'single_mixture']:
                if not left_bucket.is_mixed:
                    aim_var -= 1
                if not right_bucket.is_mixed:
                    aim_var -= 1
                if not new_bucket.is_mixed:
                    aim_var += 1
            elif constraint in ['small_size', 'single_small_size']:
                if left_bucket.total_count < min_item_num:
                    aim_var -= 1
                if right_bucket.total_count < min_item_num:
                    aim_var -= 1
                if new_bucket.total_count < min_item_num:
                    aim_var += 1
            else:
                aim_var = len(bucket_dict) - final_max_bin
            return aim_var

        if optimal_param.mixture:
            LOGGER.debug("Before mixture add, dict length: {}".format(len(bucket_dict)))
            min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='mixture')
            min_heap, non_mixture_num = _merge_heap(constraint='mixture', aim_var=non_mixture_num)
            bucket_dict = _update_bucket_info(bucket_dict)

            min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='single_mixture')
            min_heap, non_mixture_num = _merge_heap(constraint='single_mixture', aim_var=non_mixture_num)
            LOGGER.debug("After mixture merge, min_heap size: {}, non_mixture_num: {}".format(min_heap.size,
                                                                                              non_mixture_num))
            bucket_dict = _update_bucket_info(bucket_dict)

        LOGGER.debug("Before small_size add, dict length: {}".format(len(bucket_dict)))
        min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='small_size')
        min_heap, small_size_num = _merge_heap(constraint='small_size', aim_var=small_size_num)
        bucket_dict = _update_bucket_info(bucket_dict)

        min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='single_small_size')
        min_heap, small_size_num = _merge_heap(constraint='single_small_size', aim_var=small_size_num)

        bucket_dict = _update_bucket_info(bucket_dict)

        LOGGER.debug("Before add, dict length: {}".format(len(bucket_dict)))
        min_heap, non_mixture_num, small_size_num = _add_heap_nodes()
        LOGGER.debug("After normal add, small_size: {}, min_heap size: {}".format(small_size_num, min_heap.size))
        min_heap, total_bucket_num = _merge_heap(aim_var=len(bucket_dict) - final_max_bin)
        LOGGER.debug("After normal merge, min_heap size: {}".format(min_heap.size))

        non_mixture_num = 0
        small_size_num = 0
        for bucket in bucket_dict.values():
            if not bucket.is_mixed:
                non_mixture_num += 1
            if bucket.total_count < min_item_num:
                small_size_num += 1
        bucket_res = list(bucket_dict.values())
        bucket_res = sorted(bucket_res, key=lambda bucket: bucket.left_bound)
        LOGGER.debug("Before return merge_optimal_binning, non_mixture_num: {}, small_size_num: {}, "
                     "min_heap size: {}".format(non_mixture_num, small_size_num, min_heap.size))

        LOGGER.debug("Before return, dict length: {}".format(len(bucket_dict)))
        LOGGER.info(f"Consume time: {time.time() - t0}")
        return bucket_res, non_mixture_num, small_size_num
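The merging above is driven by a min-heap of adjacent bucket pairs: the cheapest pair is popped, merged into a new bucket, stale pairs are skipped, and new candidate pairs with the merged bucket's neighbors are pushed back. A simplified toy version of that loop, assuming the merge cost is just the combined sample count (the real code builds heap nodes via heap_node_factory with optimal_param-dependent costs and constraints):

import heapq

def merge_to_max_bins(counts, max_bin):
    # Greedily merge the adjacent bucket pair with the smallest combined count
    # until at most `max_bin` buckets remain; returns the merged counts in order.
    buckets = dict(enumerate(counts))
    left = {i: (i - 1 if i > 0 else None) for i in range(len(counts))}
    right = {i: (i + 1 if i + 1 < len(counts) else None) for i in range(len(counts))}
    heap = [(counts[i] + counts[i + 1], i, i + 1) for i in range(len(counts) - 1)]
    heapq.heapify(heap)
    next_id = len(counts)
    while len(buckets) > max_bin and heap:
        _, l, r = heapq.heappop(heap)
        if l not in buckets or r not in buckets or right.get(l) != r:
            continue  # stale pair: one side was already merged away
        new = next_id
        next_id += 1
        buckets[new] = buckets.pop(l) + buckets.pop(r)
        left[new], right[new] = left.pop(l), right.pop(r)
        left.pop(r, None)
        right.pop(l, None)
        if left[new] is not None:
            right[left[new]] = new
            heapq.heappush(heap, (buckets[left[new]] + buckets[new], left[new], new))
        if right[new] is not None:
            left[right[new]] = new
            heapq.heappush(heap, (buckets[new] + buckets[right[new]], new, right[new]))
    start = next(i for i in buckets if left[i] is None)
    ordered, cur = [], start
    while cur is not None:
        ordered.append(buckets[cur])
        cur = right[cur]
    return ordered

# the two 1-count buckets merge first, then that merged bucket joins the 5 on its left
assert merge_to_max_bins([5, 1, 1, 5, 2], max_bin=3) == [7, 5, 2]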
Example #3
    def memory_fit(self):

        """
        Fit the homo decision tree using the in-memory backend.
        """

        LOGGER.info('begin to fit homo decision tree, epoch {}, tree idx {}, '
                    'running on memory backend'.format(self.epoch_idx, self.tree_idx))

        self.init_root_node_and_gh_sum()
        g, h = self.get_g_h_arr()
        self.init_memory_hist_builder(g, h, self.arr_bin_data, self.bin_num + self.use_missing)  # extra bin for missing values
        root_indices = self.init_node2index(len(self.arr_bin_data))
        self.cur_layer_node[0].inst_indices = root_indices  # root node

        tree_height = self.max_depth + 1  # non-leaf node height + 1 layer leaf

        for dep in range(tree_height):

            if dep + 1 == tree_height:
                for node in self.cur_layer_node:
                    node.is_leaf = True
                    self.tree_node.append(node)
                break

            self.sync_cur_layer_node_num(len(self.cur_layer_node), suffix=(dep, self.epoch_idx, self.tree_idx))

            node_map = self.get_node_map(self.cur_layer_node)
            node_hists = []
            for batch_id, i in enumerate(range(0, len(self.cur_layer_node), self.max_split_nodes)):

                cur_to_split = self.cur_layer_node[i:i + self.max_split_nodes]

                for node in cur_to_split:
                    if node.id in node_map:
                        hist = self.sklearn_compute_agg_hist(node.inst_indices)
                        hist_bag = HistogramBag(hist, tensor_type='array')
                        hist_bag.hid = node.id
                        hist_bag.p_hid = node.parent_nodeid
                        node_hists.append(hist_bag)

                self.sync_local_node_histogram(node_hists, suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))
                node_hists = []

            split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
            new_layer_node = self.update_tree(self.cur_layer_node, split_info)
            node2inst_idx = []

            for node in self.cur_layer_node:
                if node.is_leaf:
                    continue
                l, r = self.assign_arr_inst(node, self.arr_bin_data, node.inst_indices, missing_bin_index=self.bin_num)
                node2inst_idx.append(l)
                node2inst_idx.append(r)
            assert len(node2inst_idx) == len(new_layer_node)

            for node, indices in zip(new_layer_node, node2inst_idx):
                node.inst_indices = indices
            self.cur_layer_node = new_layer_node

        sample_indices, weights = [], []

        for node in self.tree_node:
            if node.is_leaf:
                sample_indices += list(node.inst_indices)
                weights += [node.weight] * len(node.inst_indices)
            else:
                node.bid = self.bin_split_points[node.fid][int(node.bid)]

        # post-processing of memory backend fit
        sample_id = self.sample_id_arr[sample_indices]
        self.leaf_count = {}
        for node in self.tree_node:
            if node.is_leaf:
                self.leaf_count[node.id] = len(node.inst_indices)
        LOGGER.debug('leaf count is {}'.format(self.leaf_count))
        sample_id_type = type(self.g_h.take(1)[0][0])
        self.sample_weights = session.parallelize([(sample_id_type(id_), weight) for id_, weight in zip(sample_id, weights)],
                                                  include_key=True, partition=self.data_bin.partitions)
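memory_fit() relies on sklearn_compute_agg_hist() to aggregate gradients per feature bin for the samples assigned to a node. A hedged sketch of that kind of histogram aggregation with plain NumPy (the function name, array layout, and shapes here are illustrative assumptions, not the FATE implementation):

import numpy as np

def compute_node_histogram(bin_data, g, h, inst_indices, bin_num):
    # bin_data: (n_samples, n_features) integer bin indices in [0, bin_num)
    # returns an array of shape (n_features, bin_num, 2) holding per-bin sums of g and h
    sub = bin_data[inst_indices]
    hist = np.zeros((bin_data.shape[1], bin_num, 2))
    for fid in range(bin_data.shape[1]):
        hist[fid, :, 0] = np.bincount(sub[:, fid], weights=g[inst_indices], minlength=bin_num)
        hist[fid, :, 1] = np.bincount(sub[:, fid], weights=h[inst_indices], minlength=bin_num)
    return hist

rng = np.random.default_rng(0)
bin_data = rng.integers(0, 4, size=(10, 3))
g, h = rng.normal(size=10), np.ones(10)
hist = compute_node_histogram(bin_data, g, h, np.arange(5), bin_num=4)
assert hist.shape == (3, 4, 2)
assert np.isclose(hist[0, :, 1].sum(), 5)  # h is all ones, node holds 5 samples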
Example #4
 def set_transfer_variable(self):
     if self.transfer_variable is not None:
         LOGGER.debug(
             "set flowid to transfer_variable, flowid: {}".format(self.flowid)
         )
         self.transfer_variable.set_flowid(self.flowid)
Example #5
    def init_bucket(self, data_instances):
        header = data_overview.get_header(data_instances)
        self._default_setting(header)

        init_bucket_param = copy.deepcopy(self.params)
        init_bucket_param.bin_num = self.optimal_param.init_bin_nums
        if self.optimal_param.init_bucket_method == consts.QUANTILE:
            init_binning_obj = QuantileBinningTool(param_obj=init_bucket_param, allow_duplicate=False)
        else:
            init_binning_obj = BucketBinning(params=init_bucket_param)
        init_binning_obj.set_bin_inner_param(self.bin_inner_param)
        init_split_points = init_binning_obj.fit_split_points(data_instances)
        is_sparse = data_overview.is_sparse_data(data_instances)

        bucket_dict = dict()
        for col_name, sps in init_split_points.items():

            # bucket_list = [bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
            #                for idx, sp in enumerate(sps)]
            bucket_list = []
            for idx, sp in enumerate(sps):
                bucket = bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
                if idx == 0:
                    bucket.left_bound = -math.inf
                    bucket.set_left_neighbor(None)
                else:
                    bucket.left_bound = sps[idx - 1]
                bucket.event_total = self.event_total
                bucket.non_event_total = self.non_event_total
                bucket_list.append(bucket)
            bucket_list[-1].set_right_neighbor(None)
            bucket_dict[col_name] = bucket_list
            LOGGER.debug(f"col_name: {col_name}, length of sps: {len(sps)}, "
                         f"length of list: {len(bucket_list)}")

        # bucket_table = data_instances.mapPartitions2(convert_func)
        # bucket_table = bucket_table.reduce(self.merge_bucket_list, key_func=lambda key: key[1])
        from fate_arch.common.versions import get_eggroll_version
        version = get_eggroll_version()
        if version.startswith('2.0'):
            convert_func = functools.partial(self.convert_data_to_bucket_old,
                                             split_points=init_split_points,
                                             headers=self.header,
                                             bucket_dict=copy.deepcopy(bucket_dict),
                                             is_sparse=is_sparse,
                                             get_bin_num_func=self.get_bin_num)
            summary_dict = data_instances.mapPartitions(convert_func, use_previous_behavior=False)
            # summary_dict = summary_dict.reduce(self.copy_merge, key_func=lambda key: key[1])
            from federatedml.util.reduce_by_key import reduce
            bucket_table = reduce(summary_dict, self.merge_bucket_list, key_func=lambda key: key[1])
        elif version.startswith('2.2'):
            convert_func = functools.partial(self.convert_data_to_bucket,
                                             split_points=init_split_points,
                                             headers=self.header,
                                             bucket_dict=copy.deepcopy(bucket_dict),
                                             is_sparse=is_sparse,
                                             get_bin_num_func=self.get_bin_num)
            bucket_table = data_instances.mapReducePartitions(convert_func, self.merge_bucket_list)
            bucket_table = dict(bucket_table.collect())
        else:
            raise RuntimeError(f"Cannot recognize eggroll version: {version}")

        for k, v in bucket_table.items():
            LOGGER.debug(f"[feature] {k}, length of list: {len(v)}")

        LOGGER.debug("bucket_table: {}, length: {}".format(type(bucket_table), len(bucket_table)))
        bucket_table = [(k, v) for k, v in bucket_table.items()]
        LOGGER.debug("bucket_table: {}, length: {}".format(type(bucket_table), len(bucket_table)))

        bucket_table = session.parallelize(bucket_table, include_key=True, partition=data_instances.partitions)

        return bucket_table
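init_bucket() turns each column's split points into a chain of buckets whose left bound is the previous split point (the first bucket opens at -inf). A minimal sketch of that boundary construction with plain tuples instead of bucket_info.Bucket (illustrative only):

import math

def build_bucket_bounds(split_points):
    # Return (left_bound, right_bound) pairs mirroring the loop above.
    bounds = []
    for idx, sp in enumerate(split_points):
        left = -math.inf if idx == 0 else split_points[idx - 1]
        bounds.append((left, sp))
    return bounds

assert build_bucket_bounds([1.0, 2.5, 7.0]) == [(-math.inf, 1.0), (1.0, 2.5), (2.5, 7.0)]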
Example #6
    def fit_binary(self, data_instances, validate_data):
        self._abnormal_detection(data_instances)
        self.check_abnormal_values(data_instances)
        self.check_abnormal_values(validate_data)
        # self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)
        self.callback_list.on_train_begin(data_instances, validate_data)

        LOGGER.debug(f"MODEL_STEP Start fit_binary, data count: {data_instances.count()}")

        self.header = self.get_header(data_instances)
        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        if self.transfer_variable.use_async.get(idx=0):
            LOGGER.debug(f"set_use_async")
            self.gradient_loss_operator.set_use_async()

        self.batch_generator.initialize_batch_generator(data_instances)
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)

        self.encrypted_calculator = [EncryptModeCalculator(self.cipher_operator,
                                                           self.encrypted_mode_calculator_param.mode,
                                                           self.encrypted_mode_calculator_param.re_encrypted_rate) for _
                                     in range(self.batch_generator.batch_nums)]

        LOGGER.info("Start initialize model.")
        model_shape = self.get_features_shape(data_instances)
        if self.init_param_obj.fit_intercept:
            self.init_param_obj.fit_intercept = False

        if not self.component_properties.is_warm_start:
            w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
            self.model_weights = LinearModelWeights(w, fit_intercept=self.init_param_obj.fit_intercept)
        else:
            self.callback_warm_start_init_iter(self.n_iter_)

        while self.n_iter_ < self.max_iter:
            self.callback_list.on_epoch_begin(self.n_iter_)

            LOGGER.info("iter:" + str(self.n_iter_))
            batch_data_generator = self.batch_generator.generate_batch_data()
            batch_index = 0
            self.optimizer.set_iters(self.n_iter_)
            for batch_data in batch_data_generator:
                # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
                batch_feat_inst = batch_data
                # LOGGER.debug(f"MODEL_STEP In Batch {batch_index}, batch data count: {batch_feat_inst.count()}")

                optim_host_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst, self.encrypted_calculator, self.model_weights, self.optimizer, self.n_iter_,
                    batch_index)
                # LOGGER.debug('optim_host_gradient: {}'.format(optim_host_gradient))

                self.gradient_loss_operator.compute_loss(self.model_weights, self.optimizer,
                                                         self.n_iter_, batch_index, self.cipher_operator)

                self.model_weights = self.optimizer.update_model(self.model_weights, optim_host_gradient)
                batch_index += 1

            self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))

            LOGGER.info("Get is_converged flag from arbiter:{}".format(self.is_converged))
            LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
            LOGGER.debug(f"flowid: {self.flowid}, step_index: {self.n_iter_}")

            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1
            if self.stop_training:
                break

            if self.is_converged:
                break
        self.callback_list.on_train_end()
        self.set_summary(self.get_model_summary())
Example #7
    def fit(self, data_instances, validate_data=None):
        """
        Train the Poisson regression model for the host role.

        Parameters
        ----------
        data_instances: DTable of Instance, input data
        """

        LOGGER.info("Enter hetero_poisson host")
        self._abnormal_detection(data_instances)

        self.validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)

        self.header = self.get_header(data_instances)
        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        self.batch_generator.initialize_batch_generator(data_instances)

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.cipher_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(self.batch_generator.batch_nums)
        ]

        LOGGER.info("Start initialize model.")
        model_shape = self.get_features_shape(data_instances)
        if self.init_param_obj.fit_intercept:
            self.init_param_obj.fit_intercept = False
        w = self.initializer.init_model(model_shape,
                                        init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(
            w, fit_intercept=self.fit_intercept)

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:" + str(self.n_iter_))

            batch_data_generator = self.batch_generator.generate_batch_data()
            self.optimizer.set_iters(self.n_iter_)

            batch_index = 0
            for batch_data in batch_data_generator:
                batch_feat_inst = self.transform(batch_data)
                optim_host_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst, self.encrypted_calculator,
                    self.model_weights, self.optimizer, self.n_iter_,
                    batch_index)

                self.gradient_loss_operator.compute_loss(
                    batch_feat_inst, self.model_weights,
                    self.encrypted_calculator, self.optimizer, self.n_iter_,
                    batch_index, self.cipher_operator)

                self.model_weights = self.optimizer.update_model(
                    self.model_weights, optim_host_gradient)
                batch_index += 1

            self.is_converged = self.converge_procedure.sync_converge_info(
                suffix=(self.n_iter_, ))

            LOGGER.info("Get is_converged flag from arbiter:{}".format(
                self.is_converged))

            if self.validation_strategy:
                LOGGER.debug('Poisson host running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            self.n_iter_ += 1
            LOGGER.info("iter: {}, is_converged: {}".format(
                self.n_iter_, self.is_converged))
            if self.is_converged:
                break

        if not self.is_converged:
            LOGGER.info("Reached max iter {}, training finished!".format(
                self.max_iter))

        if self.validation_strategy and self.validation_strategy.has_saved_best_model(
        ):
            self.load_model(self.validation_strategy.cur_best_model)
        self.set_summary(self.get_model_summary())
Example #8
    def set_model_param(self, model_param):

        self.nn.restore_model(model_param.model_bytes)
        self.store_header = list(model_param.header)
        LOGGER.debug('stored header load, is {}'.format(self.store_header))
Example #9
 def debug_data_inst(data_inst):
     collect_data = list(data_inst.collect())
     LOGGER.debug('showing DTable')
     for d in collect_data:
         LOGGER.debug('key {} id {}, features {} label {}'.format(d[0], d[1].inst_id, d[1].features, d[1].label))
Example #10
    def predict(self, data_inst, ret_format='std'):

        # standard format, leaf indices, raw score
        assert ret_format in ['std', 'leaf', 'raw'], 'illegal ret format'

        LOGGER.info('running prediction')
        cache_dataset_key = self.predict_data_cache.get_data_key(data_inst)

        processed_data = self.data_and_header_alignment(data_inst)

        last_round = self.predict_data_cache.predict_data_last_round(
            cache_dataset_key)

        self.sync_predict_round(last_round)

        rounds = len(self.boosting_model_list) // self.booster_dim
        trees = []
        LOGGER.debug(
            'rounds involved in prediction {}, last round is {}, data key {}'.
            format(list(range(last_round, rounds)), last_round,
                   cache_dataset_key))

        for idx in range(last_round, rounds):
            for booster_idx in range(self.booster_dim):
                tree = self.load_booster(
                    self.booster_meta,
                    self.boosting_model_list[idx * self.booster_dim +
                                             booster_idx], idx, booster_idx)
                trees.append(tree)

        predict_cache = None
        tree_num = len(trees)

        if last_round != 0:
            predict_cache = self.predict_data_cache.predict_data_at(
                cache_dataset_key, min(rounds, last_round))
            LOGGER.info('load predict cache of round {}'.format(
                min(rounds, last_round)))

        if tree_num == 0 and predict_cache is not None and ret_format != 'leaf':
            return self.score_to_predict_result(data_inst, predict_cache)

        predict_rs = self.boosting_fast_predict(
            processed_data,
            trees=trees,
            predict_cache=predict_cache,
            pred_leaf=(ret_format == 'leaf'))

        if ret_format == 'leaf':
            return predict_rs  # predict result is leaf position

        self.predict_data_cache.add_data(cache_dataset_key,
                                         predict_rs,
                                         cur_boosting_round=rounds)
        LOGGER.debug('adding predict rs {}'.format(predict_rs))
        LOGGER.debug('last round is {}'.format(
            self.predict_data_cache.predict_data_last_round(
                cache_dataset_key)))

        if ret_format == 'raw':
            return predict_rs
        else:
            return self.score_to_predict_result(data_inst, predict_rs)
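predict() only evaluates the boosters added since the cached round: rounds [0, last_round) come from the cache and the remaining rounds are loaded via load_booster(). A small sketch of which (round, booster) indices get evaluated under that scheme (plain Python, illustrative names):

def new_booster_indices(boosting_model_list, booster_dim, last_round):
    # Flat (round_idx, booster_idx) pairs loaded by the loop in predict() above.
    rounds = len(boosting_model_list) // booster_dim
    return [(idx, b) for idx in range(last_round, rounds) for b in range(booster_dim)]

# 6 stored boosters with booster_dim=2 -> 3 rounds; with 1 round cached, rounds 1 and 2 remain
assert new_booster_indices(list(range(6)), booster_dim=2, last_round=1) == [
    (1, 0), (1, 1), (2, 0), (2, 1)]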
Example #11
    def fit(self, data_inst, validate_data=None):
        self.validation_strategy = self.init_validation_strategy(
            data_inst, validate_data)
        self._build_model()
        self.prepare_batch_data(self.batch_generator, data_inst)
        if not self.input_shape:
            self.model.set_empty()

        self._set_loss_callback_info()
        cur_epoch = 0
        while cur_epoch < self.epochs:
            LOGGER.debug("cur epoch is {}".format(cur_epoch))
            epoch_loss = 0

            for batch_idx in range(len(self.data_x)):
                self.model.train(self.data_x[batch_idx],
                                 self.data_y[batch_idx], cur_epoch, batch_idx)

                self.reset_flowid()
                metrics = self.model.evaluate(self.data_x[batch_idx],
                                              self.data_y[batch_idx],
                                              cur_epoch, batch_idx)
                self.recovery_flowid()

                LOGGER.debug("metrics is {}".format(metrics))
                batch_loss = metrics["loss"]

                epoch_loss += batch_loss

            epoch_loss /= len(self.data_x)

            LOGGER.debug("epoch {} loss is {}".format(cur_epoch, epoch_loss))

            self.callback_metric("loss", "train",
                                 [Metric(cur_epoch, epoch_loss)])

            self.history_loss.append(epoch_loss)

            if self.validation_strategy:
                self.validation_strategy.validate(self, cur_epoch)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            is_converge = self.converge_func.is_converge(epoch_loss)
            self._summary_buf["is_converged"] = is_converge
            self.transfer_variable.is_converge.remote(is_converge,
                                                      role=consts.HOST,
                                                      idx=0,
                                                      suffix=(cur_epoch, ))

            if is_converge:
                LOGGER.debug(
                    "Training process is converged in epoch {}".format(
                        cur_epoch))
                break

            cur_epoch += 1

        if cur_epoch == self.epochs:
            LOGGER.debug(
                "Training process reached max training epochs {} and did not converge"
                .format(self.epochs))

        if self.validation_strategy and self.validation_strategy.has_saved_best_model(
        ):
            self.load_model(self.validation_strategy.cur_best_model)

        self.set_summary(self._get_model_summary())
Example #12
    def fit(self, data_instances):
        """
        Apply the binning method to the data instances of the local party as well as those of the other
        party. Afterwards, calculate the specified metric for the chosen columns. Currently, IV is
        supported for binary-labeled data only.
        """
        LOGGER.info("Start feature binning fit and transform")
        self._abnormal_detection(data_instances)

        # self._parse_cols(data_instances)
        self._setup_bin_inner_param(data_instances, self.model_param)

        self.binning_obj.fit_split_points(data_instances)
        if self.model_param.skip_static:
            self.transform(data_instances)
            return self.data_output

        label_counts = data_overview.count_labels(data_instances)
        if label_counts > 2:
            raise ValueError(
                "IV calculation supports binary-labeled data only in this version.")

        data_instances = data_instances.mapValues(self.load_data)
        self.set_schema(data_instances)
        label_table = data_instances.mapValues(lambda x: x.label)

        if self.model_param.local_only:
            LOGGER.info("This is a local only binning fit")
            self.binning_obj.cal_local_iv(data_instances,
                                          label_table=label_table)
            self.transform(data_instances)
            self.set_summary(self.binning_obj.bin_results.summary())
            LOGGER.debug(f"Summary is: {self.summary()}")
            return self.data_output

        cipher = PaillierEncrypt()
        cipher.generate_key()

        f = functools.partial(self.encrypt, cipher=cipher)
        encrypted_label_table = label_table.mapValues(f)

        self.transfer_variable.encrypted_label.remote(encrypted_label_table,
                                                      role=consts.HOST,
                                                      idx=-1)
        LOGGER.info("Sent encrypted_label_table to host")

        self.binning_obj.cal_local_iv(data_instances, label_table=label_table)

        encrypted_bin_infos = self.transfer_variable.encrypted_bin_sum.get(
            idx=-1)
        # LOGGER.debug("encrypted_bin_sums: {}".format(encrypted_bin_sums))

        total_summary = self.binning_obj.bin_results.summary()

        LOGGER.info("Get encrypted_bin_sum from host")
        for host_idx, encrypted_bin_info in enumerate(encrypted_bin_infos):
            host_party_id = self.component_properties.host_party_idlist[
                host_idx]
            encrypted_bin_sum = encrypted_bin_info['encrypted_bin_sum']
            host_bin_methods = encrypted_bin_info['bin_method']
            category_names = encrypted_bin_info['category_names']
            result_counts = self.__decrypt_bin_sum(encrypted_bin_sum, cipher)
            LOGGER.debug(
                "Received host {} result, length of buckets: {}".format(
                    host_idx, len(result_counts)))
            LOGGER.debug("category_name: {}, host_bin_methods: {}".format(
                category_names, host_bin_methods))
            # if self.model_param.method == consts.OPTIMAL:
            if host_bin_methods == consts.OPTIMAL:
                optimal_binning_params = encrypted_bin_info['optimal_params']

                host_model_params = copy.deepcopy(self.model_param)
                host_model_params.bin_num = optimal_binning_params.get(
                    'bin_num')
                host_model_params.optimal_binning_param.metric_method = optimal_binning_params.get(
                    'metric_method')
                host_model_params.optimal_binning_param.mixture = optimal_binning_params.get(
                    'mixture')
                host_model_params.optimal_binning_param.max_bin_pct = optimal_binning_params.get(
                    'max_bin_pct')
                host_model_params.optimal_binning_param.min_bin_pct = optimal_binning_params.get(
                    'min_bin_pct')

                self.binning_obj.event_total, self.binning_obj.non_event_total = self.get_histogram(
                    data_instances)
                optimal_binning_cols = {
                    x: y
                    for x, y in result_counts.items()
                    if x not in category_names
                }
                host_binning_obj = self.optimal_binning_sync(
                    optimal_binning_cols, data_instances.count(),
                    data_instances.partitions, host_idx, host_model_params)
                category_bins = {
                    x: y
                    for x, y in result_counts.items() if x in category_names
                }
                host_binning_obj.cal_iv_woe(category_bins,
                                            self.model_param.adjustment_factor)
            else:
                host_binning_obj = BaseBinning()
                host_binning_obj.cal_iv_woe(result_counts,
                                            self.model_param.adjustment_factor)
            host_binning_obj.set_role_party(role=consts.HOST,
                                            party_id=host_party_id)
            total_summary = self._merge_summary(
                total_summary, host_binning_obj.bin_results.summary())
            self.host_results.append(host_binning_obj)

        self.set_schema(data_instances)
        self.transform(data_instances)
        LOGGER.info("Finish feature binning fit and transform")
        self.set_summary(total_summary)
        LOGGER.debug(f"Summary is: {self.summary()}")
        return self.data_output
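For each host bucket the guest decrypts the event/non-event counts and then computes WOE and IV via cal_iv_woe(). A hedged sketch of that arithmetic under one common convention (WOE = ln(event_rate / non_event_rate)); the actual FATE code additionally applies an adjustment_factor to zero counts, which this toy version assumes away:

import math

def woe_iv(event_counts, non_event_counts):
    # Assumes strictly positive counts in every bin (no adjustment_factor handling).
    event_total, non_event_total = sum(event_counts), sum(non_event_counts)
    woe, iv = [], 0.0
    for event, non_event in zip(event_counts, non_event_counts):
        event_rate = event / event_total
        non_event_rate = non_event / non_event_total
        w = math.log(event_rate / non_event_rate)
        woe.append(w)
        iv += (event_rate - non_event_rate) * w
    return woe, iv

woe, iv = woe_iv([10, 30, 60], [40, 35, 25])
assert iv > 0 and woe[0] < 0 < woe[2]  # low-event bins get negative WOE under this convention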
Example #13
    def fit(self, data_instances=None, validate_data=None):
        """
        Train the linear model for the arbiter role.

        Parameters
        ----------
        data_instances: DTable of Instance, input data
        """

        LOGGER.info("Enter hetero linear model arbiter fit")

        self.cipher_operator = self.cipher.paillier_keygen(
            self.model_param.encrypt_param.key_length)
        self.batch_generator.initialize_batch_generator()
        self.gradient_loss_operator.set_total_batch_nums(
            self.batch_generator.batch_num)

        self.validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)

        while self.n_iter_ < self.max_iter:
            iter_loss = None
            batch_data_generator = self.batch_generator.generate_batch_data()
            total_gradient = None
            self.optimizer.set_iters(self.n_iter_)
            for batch_index in batch_data_generator:
                # Compute and Transfer gradient info
                gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    self.cipher_operator, self.optimizer, self.n_iter_,
                    batch_index)
                if total_gradient is None:
                    total_gradient = gradient
                else:
                    total_gradient = total_gradient + gradient
                training_info = {
                    "iteration": self.n_iter_,
                    "batch_index": batch_index
                }
                self.perform_subtasks(**training_info)

                loss_list = self.gradient_loss_operator.compute_loss(
                    self.cipher_operator, self.n_iter_, batch_index)

                if len(loss_list) == 1:
                    if iter_loss is None:
                        iter_loss = loss_list[0]
                    else:
                        iter_loss += loss_list[0]
                        # LOGGER.info("Get loss from guest:{}".format(de_loss))

            # if converge
            if iter_loss is not None:
                iter_loss /= self.batch_generator.batch_num
                if self.need_call_back_loss:
                    self.callback_loss(self.n_iter_, iter_loss)
                self.loss_history.append(iter_loss)

            if self.model_param.early_stop == 'weight_diff':
                # LOGGER.debug("total_gradient: {}".format(total_gradient))
                weight_diff = fate_operator.norm(total_gradient)
                # LOGGER.info("iter: {}, weight_diff:{}, is_converged: {}".format(self.n_iter_,
                #                                                                 weight_diff, self.is_converged))
                if weight_diff < self.model_param.tol:
                    self.is_converged = True
            else:
                if iter_loss is None:
                    raise ValueError(
                        "Loss-based early stopping is not available in the multi-host situation; "
                        "use 'weight_diff' instead.")
                self.is_converged = self.converge_func.is_converge(iter_loss)
                LOGGER.info("iter: {},  loss:{}, is_converged: {}".format(
                    self.n_iter_, iter_loss, self.is_converged))

            self.converge_procedure.sync_converge_info(self.is_converged,
                                                       suffix=(self.n_iter_, ))

            if self.validation_strategy:
                LOGGER.debug('Linear Arbiter running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    self.best_iteration = self.n_iter_
                    break

            self.n_iter_ += 1
            if self.is_converged:
                break
        summary = {
            "loss_history": self.loss_history,
            "is_converged": self.is_converged,
            "best_iteration": self.best_iteration
        }
        if self.validation_strategy and self.validation_strategy.has_saved_best_model(
        ):
            self.load_model(self.validation_strategy.cur_best_model)
        if self.loss_history is not None and len(self.loss_history) > 0:
            summary["best_iter_loss"] = self.loss_history[self.best_iteration]

        self.set_summary(summary)
        LOGGER.debug("finish running linear model arbiter")
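Under the 'weight_diff' early-stop rule the arbiter sums the gradients of an epoch and declares convergence once their norm falls below tol. A minimal sketch of that check using numpy.linalg.norm in place of fate_operator.norm (illustrative, assuming plain ndarray gradients):

import numpy as np

def weight_diff_converged(batch_gradients, tol):
    # Sum the per-batch gradients of one epoch and compare the L2 norm against tol.
    total_gradient = np.sum(batch_gradients, axis=0)
    return float(np.linalg.norm(total_gradient)) < tol

assert weight_diff_converged(np.array([[1e-4, -1e-4], [2e-4, 0.0]]), tol=1e-3)
assert not weight_diff_converged(np.array([[0.5, 0.5]]), tol=1e-3)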
Example #14
    def fit_binary(self, data_instances, validate_data):
        self._abnormal_detection(data_instances)
        self.check_abnormal_values(data_instances)
        self.check_abnormal_values(validate_data)
        self.validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        LOGGER.debug(
            f"MODEL_STEP Start fit_binary, data count: {data_instances.count()}"
        )

        self.header = self.get_header(data_instances)
        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        self.batch_generator.initialize_batch_generator(data_instances)
        self.gradient_loss_operator.set_total_batch_nums(
            self.batch_generator.batch_nums)

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.cipher_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(self.batch_generator.batch_nums)
        ]

        LOGGER.info("Start initialize model.")
        model_shape = self.get_features_shape(data_instances)
        if self.init_param_obj.fit_intercept:
            self.init_param_obj.fit_intercept = False
        w = self.initializer.init_model(model_shape,
                                        init_params=self.init_param_obj)
        # LOGGER.debug("model_shape: {}, w shape: {}, w: {}".format(model_shape, w.shape, w))
        self.model_weights = LinearModelWeights(
            w, fit_intercept=self.init_param_obj.fit_intercept)

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:" + str(self.n_iter_))
            batch_data_generator = self.batch_generator.generate_batch_data()
            batch_index = 0
            self.optimizer.set_iters(self.n_iter_)
            for batch_data in batch_data_generator:
                # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
                batch_feat_inst = self.transform(batch_data)
                LOGGER.debug(
                    f"MODEL_STEP In Batch {batch_index}, batch data count: {batch_feat_inst.count()}"
                )

                optim_host_gradient, fore_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst, self.encrypted_calculator,
                    self.model_weights, self.optimizer, self.n_iter_,
                    batch_index)
                # LOGGER.debug('optim_host_gradient: {}'.format(optim_host_gradient))

                training_info = {
                    "iteration": self.n_iter_,
                    "batch_index": batch_index
                }
                self.update_local_model(fore_gradient, data_instances,
                                        self.model_weights.coef_,
                                        **training_info)

                self.gradient_loss_operator.compute_loss(
                    self.model_weights, self.optimizer, self.n_iter_,
                    batch_index, self.cipher_operator)

                self.model_weights = self.optimizer.update_model(
                    self.model_weights, optim_host_gradient)
                batch_index += 1

            self.is_converged = self.converge_procedure.sync_converge_info(
                suffix=(self.n_iter_, ))

            LOGGER.info("Get is_converged flag from arbiter:{}".format(
                self.is_converged))

            if self.validation_strategy:
                LOGGER.debug('LR host running validation')
                self.validation_strategy.validate(self, self.n_iter_)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break
            self.n_iter_ += 1
            LOGGER.info("iter: {}, is_converged: {}".format(
                self.n_iter_, self.is_converged))
            if self.is_converged:
                break
        if self.validation_strategy and self.validation_strategy.has_saved_best_model(
        ):
            self.load_model(self.validation_strategy.cur_best_model)
        self.set_summary(self.get_model_summary())
Example #15
 def sync_local_node_histogram(self, acc_histogram: List[HistogramBag], suffix):
     # sending local histogram
     self.aggregator.send_histogram(acc_histogram, suffix=suffix)
     LOGGER.debug('local histogram sent at layer {}'.format(suffix[0]))
Example #16
 def one_vs_rest_fit(self, train_data=None, validate_data=None):
     LOGGER.debug("Class num larger than 2, need to do one_vs_rest")
     self.one_vs_rest_obj.fit(data_instances=train_data,
                              validate_data=validate_data)
Example #17
 def _sync_class_host(self, class_set):
     LOGGER.debug("Start to get aggregate classes")
     class_nums = self.transfer_variable.aggregate_classes.get(idx=0)
     self.classes = [x for x in range(class_nums)]
Example #18
 def __add__(self, other):
     LOGGER.debug("In binary_op0, _w: {}".format(self._weights))
     return self.binary_op(other, operator.add, inplace=False)
Example #19
    def fit(self, data_inst, validate_data=None):

        LOGGER.debug('in training, partitions is {}'.format(
            data_inst.partitions))
        LOGGER.info('start to fit an FTL model, '
                    'run mode is {}, '
                    'communication efficient mode is {}'.format(
                        self.mode, self.comm_eff))

        self.check_host_number()

        data_loader, self.x_shape, self.data_num, self.overlap_num = self.prepare_data(
            self.init_intersect_obj(), data_inst, guest_side=True)
        self.input_dim = self.x_shape[0]

        # cache data_loader for faster validation
        self.cache_dataloader[self.get_dataset_key(data_inst)] = data_loader

        self.partitions = data_inst.partitions
        LOGGER.debug('self partitions is {}'.format(self.partitions))

        self.initialize_nn(input_shape=self.x_shape)
        self.feat_dim = self.nn._model.output_shape[1]
        self.constant_k = 1 / self.feat_dim
        self.validation_strategy = self.init_validation_strategy(
            data_inst, validate_data)

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"unit_name": "iters"}))

        # compute intermediate result of first epoch
        self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(
            data_loader)

        for epoch_idx in range(self.epochs):

            LOGGER.debug('fitting epoch {}'.format(epoch_idx))

            host_components = self.exchange_components(self.send_components,
                                                       epoch_idx=epoch_idx)

            loss = None

            for local_round_idx in range(self.local_round):

                if self.comm_eff:
                    LOGGER.debug(
                        'running local iter {}'.format(local_round_idx))

                grads = self.compute_backward_gradients(
                    host_components,
                    data_loader,
                    epoch_idx=epoch_idx,
                    local_round=local_round_idx)
                self.update_nn_weights(grads,
                                       data_loader,
                                       epoch_idx,
                                       decay=self.comm_eff)

                if local_round_idx == 0:
                    loss = self.compute_loss(
                        host_components, epoch_idx,
                        len(data_loader.get_overlap_indexes()))

                if local_round_idx + 1 != self.local_round:
                    self.phi, self.overlap_ua = self.compute_phi_and_overlap_ua(
                        data_loader)

            self.callback_metric("loss", "train", [Metric(epoch_idx, loss)])
            self.history_loss.append(loss)

            # updating variables for next epochs
            if epoch_idx + 1 == self.epochs:
                # only need to update phi in last epochs
                self.phi, _ = self.compute_phi_and_overlap_ua(data_loader)
            else:
                # compute phi, phi_product, overlap_ua etc. for next epoch
                self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(
                    data_loader)

            # check early_stopping_rounds
            if self.validation_strategy is not None:
                self.validation_strategy.validate(self, epoch_idx)
                if self.validation_strategy.need_stop():
                    LOGGER.debug('early stopping triggered')
                    break

            # check n_iter_no_change
            if self.n_iter_no_change is True:
                if self.check_convergence(loss):
                    self.sync_stop_flag(epoch_idx, stop_flag=True)
                    break
                else:
                    self.sync_stop_flag(epoch_idx, stop_flag=False)

            LOGGER.debug('fitting epoch {} done, loss is {}'.format(
                epoch_idx, loss))

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"Best": min(self.history_loss)}))

        self.set_summary(self.generate_summary())
        LOGGER.debug('fitting ftl model done')
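When n_iter_no_change is enabled, fit() above calls check_convergence(loss) each epoch and syncs a stop flag to the host. A hedged sketch of a loss-difference convergence check of that kind (the actual converge function in FATE may use a different criterion; the eps value and class name here are illustrative):

class DiffConverge:
    # Declare convergence once the absolute change in loss drops below eps.
    def __init__(self, eps=1e-4):
        self.eps = eps
        self.pre_loss = None

    def is_converge(self, loss):
        converged = self.pre_loss is not None and abs(self.pre_loss - loss) < self.eps
        self.pre_loss = loss
        return converged

checker = DiffConverge(eps=1e-3)
assert [checker.is_converge(x) for x in (1.0, 0.5, 0.4995)] == [False, False, True]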
    def preprocess(self):

        if self.multi_mode == consts.MULTI_OUTPUT:
            self.booster_dim = 1
            LOGGER.debug('multi mode tree dim reset to 1')