def fit_predict(model_name, model_param, trn_samples, tst_samples):
    '''
    Train a model on the training samples and predict the test samples.

    The classifier is obtained from get_classifier(model_name, model_param),
    fitted on the training features and labels, and then used to predict the
    test features. Each test sample receives its prediction through
    set_y_pred(); nothing is returned.

    Args:
        model_name (str): model name
        model_param (dict): model parameters
        trn_samples (Samples): training samples
        tst_samples (Samples): test samples; updated in place with the
            predicted labels
    '''
    trn_Xs = trn_samples.get_Xs()
    trn_ys = trn_samples.get_ys()
    tst_Xs = tst_samples.get_Xs()

    # Labels and predictions are used exactly as produced. The former
    # placeholder overrides (first training label and first two predictions
    # forced to 1, marked TODO) were dropped: they falsified the results and
    # assumed at least one training label and two test predictions exist.
    # NOTE(review): if downstream evaluation relied on a guaranteed positive
    # label/prediction, handle that degenerate case there - TODO confirm.
    clf = get_classifier(model_name, model_param)
    clf.fit(trn_Xs, trn_ys)
    tst_ys_pred = clf.predict(tst_Xs)

    # Index by position (rather than zip) so that a prediction vector shorter
    # than the sample collection fails loudly instead of being truncated.
    for i, sample in enumerate(tst_samples):
        sample.set_y_pred(tst_ys_pred[i])
def get_x_importances(samples):
    '''
    Compute feature importances for the given samples.

    A classifier named 'OLP.core.models.RandomForestClassifier' is created
    through get_classifier() with n_estimators=100 and fitted on the samples'
    features and labels. The importances it reports are then looked up by
    the index that samples.get_x_indexes() assigns to each feature.

    Args:
        samples (Samples): training samples

    Returns:
        dict: feature importances, formatted as
            {feature1: importance1, feature2: importance2, ...}
    '''
    model_name = 'OLP.core.models.RandomForestClassifier'
    model_param = {'n_estimators': 100}
    clf = get_classifier(model_name, model_param)
    clf.fit(samples.get_Xs(), samples.get_ys())

    x_indexes = samples.get_x_indexes()
    importances_by_index = clf.get_x_importances()

    # .items() replaces the Python 2-only .iteritems(): it yields the same
    # (feature, index) pairs on Python 2 and also works on Python 3, where
    # dict.iteritems() no longer exists.
    return {x: importances_by_index[index] for x, index in x_indexes.items()}