def train_model(self, clf_type, features, labels, params=None, logging=False):
    """Run 10-fold stratified cross-validation for one classifier type.

    The method has two modes, selected by ``params``:

    * ``params is None`` -- the estimator is obtained from
      ``self._get_classifier(clf_type)`` and fitted on every training
      fold. For ``clf_type == 'svm'`` the fitted estimator must expose
      ``best_params_`` and ``best_score_`` (presumably a search wrapper
      such as a grid search -- confirm against ``_get_classifier``); the
      per-fold values are condensed into ``selected_params``. No
      test-fold metrics are computed in this mode.
    * ``params`` given -- a fixed model is built from ``params`` and
      evaluated on every test fold (accuracy, macro precision, macro
      recall).

    Args:
        clf_type: ``'svm'`` or ``'lgr'`` when ``params`` is given;
            otherwise whatever ``self._get_classifier`` accepts.
        features: Sample matrix; must support indexing with the integer
            index arrays produced by ``StratifiedKFold.split``.
        labels: Target labels, indexable the same way as ``features``.
        params: Optional dict with key ``'C'`` (and ``'Gamma'`` for
            ``'svm'``) describing the fixed model to evaluate.
        logging: When true, report classifier and cross-validation
            details through ``Log`` (only in the ``params is None`` mode).

    Returns:
        A tuple ``(accuracies, selected_params, precisions, recalls)``.
        The three metric entries are 1-D float arrays with one value per
        fold (empty when ``params is None``). ``selected_params`` is a
        dict with keys ``'C'`` and ``'Gamma'`` when ``params is None``
        and an empty dict otherwise. For a ``clf_type`` other than
        ``'svm'`` no per-fold parameters are collected, so both values
        are NaN.

    Raises:
        ValueError: If ``params`` is given and ``clf_type`` is not one of
            ``'svm'`` or ``'lgr'``.
    """
    # Function-scope import, executed once per call instead of once per fold.
    from sklearn.metrics import precision_score, recall_score

    if params is None:
        clf = self._get_classifier(clf_type)
    elif clf_type == 'svm':
        clf = SVC(C=params['C'],
                  gamma=params['Gamma'],
                  kernel='rbf',
                  decision_function_shape='ovr',
                  probability=True)
    elif clf_type == 'lgr':
        clf = LogisticRegression(C=params['C'], solver='newton-cg')
    else:
        # Previously this surfaced later as an UnboundLocalError on `clf`.
        raise ValueError(
            "Unsupported clf_type %r for fixed params; expected 'svm' or "
            "'lgr'" % (clf_type,))

    # Fixed seed keeps the fold assignment reproducible between calls.
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=23)

    fold_c, fold_gamma, fold_scores = [], [], []
    accuracies, precisions, recalls = [], [], []

    for train_idx, test_idx in skf.split(features, labels):
        clf.fit(features[train_idx], labels[train_idx])

        if params is None:
            # Parameter-selection mode: only record what the search found.
            if logging:
                Log.classifier_details(clf_type, clf)
            if clf_type == 'svm':
                fold_c.append(clf.best_params_['C'])
                fold_gamma.append(clf.best_params_['gamma'])
                fold_scores.append(clf.best_score_)
            continue

        # Evaluation mode: score the fixed model on the held-out fold.
        y_true = labels[test_idx]
        predictions = clf.predict(features[test_idx])

        # scikit-learn metrics take (y_true, y_pred); swapping them turns
        # macro precision into macro recall and vice versa.
        accuracy = accuracy_score(y_true, predictions)
        accuracies.append(accuracy)
        precisions.append(
            precision_score(y_true, predictions, average='macro'))
        recalls.append(recall_score(y_true, predictions, average='macro'))
        print('Accuracy score: ', accuracy)

    selected_params = dict()
    if params is None:
        c_values = np.array(fold_c, dtype=float)
        gamma_values = np.array(fold_gamma, dtype=float)
        scores = np.array(fold_scores, dtype=float)
        if logging:
            Log.cross_val_details(c_values, gamma_values, scores)
        # Mean parameter scaled by the mean CV score (scalar product, the
        # same value the original np.dot of two scalars produced).
        mean_score = np.mean(scores)
        selected_params = {
            'C': np.mean(c_values) * mean_score,
            'Gamma': np.mean(gamma_values) * mean_score,
        }

    return (np.array(accuracies, dtype=float), selected_params,
            np.array(precisions, dtype=float), np.array(recalls, dtype=float))