def tree_generator_with_post_pruning(input_tree, data_set_train,
                                     data_set_validate, features_list,
                                     features_dict, is_features_discrete):
    """
    Post-prune an already generated decision tree, bottom-up.

    Every sub-tree is pruned first (recursively). Afterwards the number of
    validation errors made by the tree rooted at this node is compared with
    the number made by a single leaf labelled with the most common class of
    ``data_set_train``. The tree is kept when it makes no more errors than
    the leaf (ties keep the tree); otherwise the leaf label is returned.

    The nested dicts of ``input_tree`` are updated in place.

    NOTE(review): the node key is looked up verbatim in ``features_list``
    and the data is split with ``operator.eq``, so this looks like it only
    supports discrete split nodes. A ``'<name><=<value>'`` key would make
    ``features_list.index`` raise ``ValueError`` -- confirm with callers.

    :param input_tree: nested dicts ``{feature_name: {value_name: subtree}}``;
        anything that is not a dict is treated as a leaf and returned as is
    :param data_set_train: training samples that reach this node
    :param data_set_validate: validation samples that reach this node
    :param features_list: feature names; the position of a name is used as
        the column index handed to ``split_data_set_by_operate``
    :param features_dict: per feature name, a mapping that
        ``get_keys_for_dict`` searches with the branch name (presumably a
        reverse lookup from display name to raw value -- TODO confirm)
    :param is_features_discrete: flags parallel to ``features_list``
    :return: the (possibly pruned) tree, or a class label when the whole
        sub-tree was replaced by a leaf
    """
    # Leaves are class labels. Test for "not a dict" instead of comparing the
    # type name with 'str' so that non-string labels (e.g. ints) also work.
    if not isinstance(input_tree, dict):
        return input_tree

    feature_name = list(input_tree)[0]
    sub_tree = input_tree[feature_name]
    feature_index = features_list.index(feature_name)

    for value in sub_tree:
        value_key = get_keys_for_dict(features_dict[feature_name], value)[0]
        # Each branch gets its own lists with the split feature removed, so
        # that nothing done further down can affect this level or a sibling.
        new_features_list = (features_list[:feature_index]
                             + features_list[feature_index + 1:])
        new_is_features_discrete = (is_features_discrete[:feature_index]
                                    + is_features_discrete[feature_index + 1:])
        sub_data_set_train = split_data_set_by_operate(
            data_set_train, feature_index, value_key, operator.eq,
            delete_col=True)
        sub_data_set_validate = split_data_set_by_operate(
            data_set_validate, feature_index, value_key, operator.eq,
            delete_col=True)
        # Only the value of an existing key is replaced, so mutating the
        # dict while iterating over it is safe.
        sub_tree[value] = tree_generator_with_post_pruning(
            sub_tree[value], sub_data_set_train, sub_data_set_validate,
            new_features_list, features_dict, new_is_features_discrete)

    # All children are pruned; now decide whether this node itself should be
    # collapsed into a single leaf.
    error_count_by_tree = get_validation_error_count_by_tree(
        input_tree, data_set_validate, features_list, features_dict,
        is_features_discrete)
    major_class = get_most_common_class(data_set_train)
    error_count_by_major_class = get_validation_error_count_by_major_class(
        major_class, data_set_validate)

    if error_count_by_tree <= error_count_by_major_class:
        return input_tree
    return major_class
def tree_generate_without_pruning(data_set_train, features_list,
                                  features_dict, is_features_discrete):
    """
    Recursively build a decision tree without any pruning.

    The split feature is chosen by ``get_best_feature_gini``. A discrete
    feature produces one branch per value listed in ``features_dict`` and is
    removed from the feature lists for the sub-trees. A continuous feature
    produces a binary node keyed ``'<name><=<threshold>'`` with the branches
    ``'是'`` (<= threshold) and ``'否'`` (> threshold), and stays available
    to descendants.

    :param data_set_train: training samples; the class label is the last
        element of every sample
    :param features_list: feature names (not modified; a copy is used)
    :param features_dict: per feature name, a mapping from raw feature value
        to the name used as branch key in the tree
    :param is_features_discrete: flags parallel to ``features_list``; an
        entry equal to 1 marks a discrete feature (not modified)
    :return: a class label (leaf) or a nested dict
        ``{node_key: {branch_name: subtree}}``
    :raises IndexError: if ``data_set_train`` is empty
    """
    # Shallow copies: the deletions below must stay local to this call so
    # the caller's lists and the sibling branches are unaffected.
    features_list = features_list[:]
    is_features_discrete = is_features_discrete[:]

    # All samples already belong to one class -> leaf.
    class_list = [sample[-1] for sample in data_set_train]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]

    # No feature left, or the samples cannot be told apart -> majority leaf.
    if len(features_list) == 0 or is_all_sample_same(data_set_train):
        return get_most_common_class(data_set_train)

    best_feature_index, best_continuous_feature_value = get_best_feature_gini(
        data_set_train, features_list, is_features_discrete)
    best_feature_name = features_list[best_feature_index]

    if is_features_discrete[best_feature_index] == 1:
        # A discrete feature is used at most once on a path: drop it.
        del features_list[best_feature_index]
        del is_features_discrete[best_feature_index]
        branches = {}
        for feature_value, feature_value_name in \
                features_dict[best_feature_name].items():
            sub_data_set_train = split_data_set_by_operate(
                data_set_train, best_feature_index, feature_value,
                operator.eq, delete_col=True)
            if len(sub_data_set_train) == 0:
                # No training sample takes this value: label the branch with
                # the majority class of the parent node.
                branches[feature_value_name] = get_most_common_class(
                    data_set_train)
            else:
                branches[feature_value_name] = tree_generate_without_pruning(
                    sub_data_set_train, features_list, features_dict,
                    is_features_discrete)
        return {best_feature_name: branches}

    # Continuous feature: binary split on the threshold. The column is kept
    # (delete_col=False) because the feature may be split again further down.
    key = best_feature_name + '<=' + ("%0.3f" % best_continuous_feature_value)
    sub_data_set_le = split_data_set_by_operate(
        data_set_train, best_feature_index, best_continuous_feature_value,
        operator.le, delete_col=False)
    sub_data_set_gt = split_data_set_by_operate(
        data_set_train, best_feature_index, best_continuous_feature_value,
        operator.gt, delete_col=False)
    tree = {key: {}}
    tree[key]['是'] = tree_generate_without_pruning(
        sub_data_set_le, features_list, features_dict, is_features_discrete)
    tree[key]['否'] = tree_generate_without_pruning(
        sub_data_set_gt, features_list, features_dict, is_features_discrete)
    return tree
def _continuous_split_candidates(data_set, feature_index):
    """Return the midpoints between consecutive distinct values of a column, ascending."""
    distinct_values = sorted({sample[feature_index] for sample in data_set})
    return [(low + high) / 2
            for low, high in zip(distinct_values, distinct_values[1:])]


def tree_generate_with_random_feature_selection(data_set, features_list,
                                                features_dict,
                                                is_features_discrete):
    """
    Build a decision tree whose split feature is picked at random.

    A discrete feature produces one branch per value listed in
    ``features_dict`` and is removed from the feature lists for the
    sub-trees. For a continuous feature the threshold is drawn at random
    from the midpoints between consecutive distinct values of that feature
    in ``data_set``; the node is keyed ``'<name><=<threshold>'`` with the
    branches ``'是'`` (<= threshold) and ``'否'`` (> threshold).

    If the randomly drawn continuous feature has a single distinct value in
    ``data_set`` it offers no threshold; another feature is then drawn from
    those not yet tried, and a majority-class leaf is returned when none is
    usable.

    :param data_set: samples; the class label is the last element and the
        feature columns are indexed like ``features_list``
    :param features_list: feature names (not modified; a copy is used)
    :param features_dict: per feature name, a mapping from raw feature value
        to the name used as branch key in the tree
    :param is_features_discrete: flags parallel to ``features_list``; an
        entry equal to 1 marks a discrete feature (not modified)
    :return: a class label (leaf) or a nested dict
        ``{node_key: {branch_name: subtree}}``
    :raises IndexError: if ``data_set`` is empty
    """
    # Shallow copies: without them, a feature deleted on the way down one
    # branch would be missing when the recursion returns and descends into
    # the next branch.
    features_list = features_list[:]
    is_features_discrete = is_features_discrete[:]

    # All samples already belong to one class -> leaf.
    samples_class = [sample[-1] for sample in data_set]
    if samples_class.count(samples_class[0]) == len(samples_class):
        return samples_class[0]

    if len(features_list) == 0 or is_all_sample_same(data_set):
        return get_most_common_class(data_set)

    # The first draw is the same call as before, so seeded runs reproduce the
    # same tree whenever the drawn feature is usable.
    best_feature_index = random.randint(0, len(features_list) - 1)
    untried_indices = list(range(len(features_list)))
    split_candidates = []
    while is_features_discrete[best_feature_index] != 1:
        split_candidates = _continuous_split_candidates(
            data_set, best_feature_index)
        if split_candidates:
            break
        # Constant continuous feature: random.choice on the empty candidate
        # list would raise IndexError, so draw another feature instead.
        untried_indices.remove(best_feature_index)
        if not untried_indices:
            return get_most_common_class(data_set)
        best_feature_index = random.choice(untried_indices)
    best_feature_name = features_list[best_feature_index]

    if is_features_discrete[best_feature_index] == 1:
        # A discrete feature is used at most once on a path: drop it.
        del features_list[best_feature_index]
        del is_features_discrete[best_feature_index]
        branches = {}
        for feature_value, feature_value_name in \
                features_dict[best_feature_name].items():
            sub_data_set_train = split_data_set_by_operate(
                data_set, best_feature_index, feature_value, operator.eq,
                delete_col=True)
            if len(sub_data_set_train) == 0:
                # No sample takes this value: label the branch with the
                # majority class of the parent node.
                branches[feature_value_name] = get_most_common_class(data_set)
            else:
                branches[feature_value_name] = \
                    tree_generate_with_random_feature_selection(
                        sub_data_set_train, features_list, features_dict,
                        is_features_discrete)
        return {best_feature_name: branches}

    # Continuous feature: pick one of the midpoints at random as threshold.
    best_continuous_feature_value = random.choice(split_candidates)
    key = best_feature_name + '<=' + ("%0.3f" % best_continuous_feature_value)
    sub_data_set_le = split_data_set_by_operate(
        data_set, best_feature_index, best_continuous_feature_value,
        operator.le, delete_col=False)
    sub_data_set_gt = split_data_set_by_operate(
        data_set, best_feature_index, best_continuous_feature_value,
        operator.gt, delete_col=False)
    tree = {key: {}}
    tree[key]['是'] = tree_generate_with_random_feature_selection(
        sub_data_set_le, features_list, features_dict, is_features_discrete)
    tree[key]['否'] = tree_generate_with_random_feature_selection(
        sub_data_set_gt, features_list, features_dict, is_features_discrete)
    return tree
def tree_generate_with_pre_pruning(data_set_train, data_set_validate,
                                   features_list, features_dict,
                                   is_features_discrete):
    """
    Recursively build a decision tree with pre-pruning.

    The split feature is chosen by ``get_best_feature_gini``. Before a node
    is split, the validation score without the split is compared with the
    score with the split; the node becomes a majority-class leaf unless the
    split strictly improves the score.

    A discrete feature produces one branch per value listed in
    ``features_dict`` (a value with no training sample becomes a
    majority-class leaf, as in the other generators of this module) and is
    removed from the feature lists for the sub-trees. A continuous feature
    produces a binary node keyed ``'<name><=<threshold>'`` with the branches
    ``'是'`` (<= threshold) and ``'否'`` (> threshold).

    NOTE(review): the two scores come from helpers named
    ``get_validation_error_*`` but are stored and compared as accuracy
    rates (higher is better). Confirm that the helpers return an accuracy;
    if they return an error, the comparison below is inverted.

    :param data_set_train: training samples; the class label is the last
        element of every sample
    :param data_set_validate: validation samples reaching this node
    :param features_list: feature names (not modified; a copy is used)
    :param features_dict: per feature name, a mapping from raw feature value
        to the name used as branch key in the tree
    :param is_features_discrete: flags parallel to ``features_list``; an
        entry equal to 1 marks a discrete feature (not modified)
    :return: a class label (leaf) or a nested dict
        ``{node_key: {branch_name: subtree}}``
    :raises IndexError: if ``data_set_train`` is empty
    """
    # Shallow copies: the deletions below must stay local to this call.
    features_list = features_list[:]
    is_features_discrete = is_features_discrete[:]

    # All samples already belong to one class -> leaf.
    class_list = [sample[-1] for sample in data_set_train]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]

    # No feature left, or the samples cannot be told apart -> majority leaf.
    if len(features_list) == 0 or is_all_sample_same(data_set_train):
        return get_most_common_class(data_set_train)

    best_feature_index, best_continuous_feature_value = get_best_feature_gini(
        data_set_train, features_list, is_features_discrete)
    best_feature_name = features_list[best_feature_index]

    accuracy_rate_before_pruning = get_validation_error_before_pruning(
        data_set_train, data_set_validate)
    accuracy_rate_after_pruning = get_validation_error_after_pruning(
        data_set_train, data_set_validate, best_feature_index,
        is_features_discrete[best_feature_index],
        best_continuous_feature_value)

    # The split does not beat the unsplit node on the validation set: stop.
    if accuracy_rate_before_pruning >= accuracy_rate_after_pruning:
        return get_most_common_class(data_set_train)

    if is_features_discrete[best_feature_index] == 1:
        # A discrete feature is used at most once on a path: drop it.
        del features_list[best_feature_index]
        del is_features_discrete[best_feature_index]
        branches = {}
        # Iterate over every known value of the feature, not only those seen
        # in this training subset, so the node has a branch for each value.
        for feature_value, feature_value_name in \
                features_dict[best_feature_name].items():
            sub_data_set_train = split_data_set_by_operate(
                data_set_train, best_feature_index, feature_value,
                operator.eq, delete_col=True)
            if len(sub_data_set_train) == 0:
                # No training sample takes this value: label the branch with
                # the majority class of the parent node.
                branches[feature_value_name] = get_most_common_class(
                    data_set_train)
                continue
            sub_data_set_validate = split_data_set_by_operate(
                data_set_validate, best_feature_index, feature_value,
                operator.eq, delete_col=True)
            branches[feature_value_name] = tree_generate_with_pre_pruning(
                sub_data_set_train, sub_data_set_validate, features_list,
                features_dict, is_features_discrete)
        return {best_feature_name: branches}

    # Continuous feature: binary split on the threshold; the column is kept
    # (delete_col=False) so the feature stays available further down.
    key = best_feature_name + '<=' + ("%0.3f" % best_continuous_feature_value)
    sub_data_set_train_le = split_data_set_by_operate(
        data_set_train, best_feature_index, best_continuous_feature_value,
        operator.le, delete_col=False)
    sub_data_set_validate_le = split_data_set_by_operate(
        data_set_validate, best_feature_index, best_continuous_feature_value,
        operator.le, delete_col=False)
    sub_data_set_train_gt = split_data_set_by_operate(
        data_set_train, best_feature_index, best_continuous_feature_value,
        operator.gt, delete_col=False)
    sub_data_set_validate_gt = split_data_set_by_operate(
        data_set_validate, best_feature_index, best_continuous_feature_value,
        operator.gt, delete_col=False)

    tree = {key: {}}
    tree[key]['是'] = tree_generate_with_pre_pruning(
        sub_data_set_train_le, sub_data_set_validate_le, features_list,
        features_dict, is_features_discrete)
    tree[key]['否'] = tree_generate_with_pre_pruning(
        sub_data_set_train_gt, sub_data_set_validate_gt, features_list,
        features_dict, is_features_discrete)
    return tree
def tree_generate(data_set, features_list, features_dict,
                  is_features_discrete, ID3_or_C45):
    """
    Recursively build a decision tree.

    The split feature is chosen by ``get_best_feature``, which receives
    ``ID3_or_C45`` unchanged. A discrete feature produces one branch per
    value listed in ``features_dict`` and is removed from the feature lists
    for the sub-trees. A continuous feature produces a binary node keyed
    ``'<name><=<threshold>'`` with the branches ``'是'`` (<= threshold) and
    ``'否'`` (> threshold), and stays available to descendants.

    :param data_set: samples; the class label is the last element of every
        sample
    :param features_list: feature names (not modified; a copy is used)
    :param features_dict: per feature name, a mapping from raw feature value
        to the name used as branch key in the tree
    :param is_features_discrete: flags parallel to ``features_list``; an
        entry equal to 1 marks a discrete feature (not modified)
    :param ID3_or_C45: selector forwarded to ``get_best_feature``
        (presumably chooses the split criterion -- TODO confirm the accepted
        values)
    :return: a class label (leaf) or a nested dict
        ``{node_key: {branch_name: subtree}}``
    :raises IndexError: if ``data_set`` is empty
    """
    # Work on shallow copies: this function deletes entries from the lists,
    # and passing the same list objects down every branch would make a
    # deletion in one branch visible in its siblings.
    features_list = features_list[:]
    is_features_discrete = is_features_discrete[:]

    # All samples already belong to one class -> leaf.
    class_list = [sample[-1] for sample in data_set]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]

    # No feature left, or the samples cannot be told apart -> majority leaf.
    if len(features_list) == 0 or is_all_sample_same(data_set):
        return get_most_common_class(data_set)

    best_feature_index, best_continuous_feature_value = get_best_feature(
        data_set, features_list, is_features_discrete, ID3_or_C45)
    best_feature_name = features_list[best_feature_index]

    if is_features_discrete[best_feature_index] == 1:
        # A discrete feature is used at most once on a path: drop it.
        del is_features_discrete[best_feature_index]
        del features_list[best_feature_index]
        branches = {}
        for feature_value, feature_value_name in \
                features_dict[best_feature_name].items():
            sub_data_set = split_data_set_by_operate(
                data_set, best_feature_index, feature_value, operator.eq,
                delete_col=True)
            if len(sub_data_set) == 0:
                # No sample takes this value: label the branch with the
                # majority class of the parent node.
                branches[feature_value_name] = get_most_common_class(data_set)
            else:
                branches[feature_value_name] = tree_generate(
                    sub_data_set, features_list, features_dict,
                    is_features_discrete, ID3_or_C45)
        return {best_feature_name: branches}

    # Continuous feature: binary split on the threshold. The column is kept
    # (delete_col=False) because, unlike a discrete feature, a continuous
    # one may be used again by descendant nodes.
    key = best_feature_name + '<=' + ("%0.3f" % best_continuous_feature_value)
    tree = {key: {}}
    tree[key]['是'] = tree_generate(
        split_data_set_by_operate(
            data_set, best_feature_index, best_continuous_feature_value,
            operator.le, delete_col=False),
        features_list, features_dict, is_features_discrete, ID3_or_C45)
    tree[key]['否'] = tree_generate(
        split_data_set_by_operate(
            data_set, best_feature_index, best_continuous_feature_value,
            operator.gt, delete_col=False),
        features_list, features_dict, is_features_discrete, ID3_or_C45)
    return tree