def store_model_in_demisto(model_name, model_override, train_text_data, train_tag_data, confusion_matrix):
    """Train a classifier on the given data, register it, and attach its evaluation.

    NOTE(review): ``store_model_in_demisto`` is defined again immediately after
    this function with a different signature. At import time that later
    definition replaces this one, so this version is not reachable by name.

    Args:
        model_name: Name under which the model is created and evaluated.
        model_override: Passed through as ``modelOverride`` to ``createMLModel``.
        train_text_data: Texts to train on.
        train_tag_data: Labels aligned with ``train_text_data``.
        confusion_matrix: Nested mapping (row label -> column label -> value).
            The ``'All'`` row and every ``'All'`` column are dropped before
            the matrix is sent to ``evaluateMLModel``.
    """
    def fail_on_error(response):
        # Abort through return_error when the server command reports an error.
        if is_error(response):
            return_error(get_error(response))

    classifier = demisto_ml.train_text_classifier(train_text_data, train_tag_data, True)
    create_args = {
        'modelData': demisto_ml.encode_model(classifier),
        'modelName': model_name,
        'modelLabels': demisto_ml.get_model_labels(classifier),
        'modelOverride': model_override,
    }
    fail_on_error(demisto.executeCommand('createMLModel', create_args))

    # Strip the 'All' totals row and column in a single pass.
    matrix_without_totals = {
        row_label: {col_label: cell for col_label, cell in row.items() if col_label != 'All'}
        for row_label, row in confusion_matrix.items()
        if row_label != 'All'
    }
    evaluate_args = {
        'modelConfusionMatrix': matrix_without_totals,
        'modelName': model_name,
    }
    fail_on_error(demisto.executeCommand('evaluateMLModel', evaluate_args))
def store_model_in_demisto(model_name, model_override, X, y, confusion_matrix, threshold, y_test_true, y_test_pred,
                           y_test_pred_prob):
    """Train the final model on all samples, store it, and upload its evaluation data.

    Args:
        model_name: Name under which the model is created and evaluated.
        model_override: Passed through as ``modelOverride`` to ``createMLModel``.
        X: All training texts.
        y: Labels aligned with ``X``.
        confusion_matrix: Nested mapping (row label -> column label -> value).
            The ``'All'`` row and every ``'All'`` column are removed before upload.
        threshold: Stored with the model under ``modelExtraInfo['threshold']``.
        y_test_true: True labels of the held-out test samples.
        y_test_pred: Predicted labels of the held-out test samples.
        y_test_pred_prob: Prediction scores of the held-out test samples.
    """
    def fail_on_error(response):
        # Abort through return_error when the server command reports an error.
        if is_error(response):
            return_error(get_error(response))

    classifier = demisto_ml.train_text_classifier(X, y, True)
    create_args = {
        'modelData': demisto_ml.encode_model(classifier),
        'modelName': model_name,
        'modelLabels': demisto_ml.get_model_labels(classifier),
        'modelOverride': model_override,
        'modelExtraInfo': {'threshold': threshold},
    }
    fail_on_error(demisto.executeCommand('createMLModel', create_args))

    # Strip the 'All' totals row and column in a single pass.
    matrix_without_totals = {
        row_label: {col_label: cell for col_label, cell in row.items() if col_label != 'All'}
        for row_label, row in confusion_matrix.items()
        if row_label != 'All'
    }
    evaluation_vectors = {
        'Ypred': y_test_pred,
        'Ytrue': y_test_true,
        'YpredProb': y_test_pred_prob,
    }
    evaluate_args = {
        'modelConfusionMatrix': matrix_without_totals,
        'modelName': model_name,
        'modelEvaluationVectors': evaluation_vectors,
    }
    fail_on_error(demisto.executeCommand('evaluateMLModel', evaluate_args))
def get_predictions_for_test_set(train_text_data, train_tag_data):
    """Train on a stratified split of the data and predict the held-out part.

    The split comes from an unshuffled ``StratifiedKFold``. The number of folds
    is derived from the ``trainSetRatio`` argument, and the last fold is used
    as the test set, so roughly ``1 - trainSetRatio`` of the samples are held out.

    Args:
        train_text_data: Sequence of input texts.
        train_tag_data: Sequence of labels aligned with ``train_text_data``.

    Returns:
        tuple: ``(y_test, y_pred)``. ``y_test`` is the list of true labels of
        the held-out samples. ``y_pred`` is a list of single-entry dicts
        ``{first_item: second_item}``, one per tuple returned by
        ``demisto_ml.predict``.

    Raises:
        ValueError: If ``trainSetRatio`` is not in ``[0.5, 1)``, i.e. it would
            not produce at least two folds.
    """
    X = pd.Series(train_text_data)
    y = pd.Series(train_tag_data)

    train_set_ratio = float(demisto.args()['trainSetRatio'])
    if train_set_ratio >= 1:
        # 1 - ratio would be zero or negative: no test fold is possible.
        raise ValueError('trainSetRatio must be less than 1, got {}'.format(train_set_ratio))
    # Round before truncating. In binary floating point 1 / (1 - 0.95) is
    # 19.999999999999982, so plain int() would silently give 19 folds instead of 20.
    n_splits = int(round(1.0 / (1.0 - train_set_ratio), 9))
    if n_splits < 2:
        raise ValueError('trainSetRatio must be at least 0.5 to yield two or more folds, got {}'
                         .format(train_set_ratio))

    skf = StratifiedKFold(n_splits=n_splits, shuffle=False, random_state=None)
    # Only the last (train, test) fold is used.
    train_index, test_index = list(skf.split(X, y))[-1]
    # split() yields positional indices, so index positionally.
    X_train, X_test = list(X.iloc[train_index]), list(X.iloc[test_index])
    y_train, y_test = list(y.iloc[train_index]), list(y.iloc[test_index])

    model = demisto_ml.train_text_classifier(X_train, y_train)
    ft_test_predictions = demisto_ml.predict(model, X_test)
    y_pred = [{y_tuple[0]: y_tuple[1]} for y_tuple in ft_test_predictions]
    return y_test, y_pred
def main():
    """Train a phishing text classifier from incident data, evaluate it, and optionally store it.

    All configuration is read from ``demisto.args()``. The flow is:

    1. Load incidents from a file or files.
    2. Validate that enough incidents were received.
    3. Map tags to labels and validate the labelled data.
    4. Optionally extract keywords (best-effort).
    5. Train on a train split, evaluate on the test split, and report metrics.
    6. If ``storeModel`` is ``'true'``, retrain on all samples and store the model.
    """
    args = demisto.args()
    input_value = args['input']  # renamed from `input` to avoid shadowing the builtin
    input_type = args['inputType']
    model_name = args['modelName']
    store_model = args['storeModel'] == 'true'
    model_override = args.get('overrideExistingModel', 'false') == 'true'
    target_accuracy = float(args['targetAccuracy'])
    text_field = args['textField']
    tag_fields = args['tagField'].split(",")
    labels_mapping = get_phishing_map_labels(args['phishingLabels'])
    keyword_min_score = float(args['keywordMinScore'])
    return_predictions_on_test_set = args.get('returnPredictionsOnTestSet', 'false') == 'true'
    original_text_fields = args.get('originalTextFields', '')

    if input_type.endswith("filename"):
        data = read_files_by_name(input_value, input_type.split("_")[0].strip())
    else:
        data = read_file(input_value, input_type)

    # NOTE(review): posts the raw incident count as a result entry; looks like
    # leftover debug output - confirm it is intended.
    demisto.results(len(data))
    if len(data) == 0:
        err = ['No incidents were received.',
               'Make sure that all arguments are set correctly and that incidents exist in the environment.']
        return_error(' '.join(err))
    if len(data) < MIN_INCIDENTS_THRESHOLD:
        # The check is on the total incident count, so the message must not say "per label".
        err = ['Only {} incident(s) were received.'.format(len(data)),
               'Minimum number of incidents required for training is {}.'.format(MIN_INCIDENTS_THRESHOLD),
               'Make sure that all arguments are set correctly and that enough incidents exist in the environment.']
        return_error('\n'.join(err))

    data = set_tag_field(data, tag_fields)
    data, exist_labels_counter, missing_labels_counter = get_data_with_mapped_label(data, labels_mapping,
                                                                                    DBOT_TAG_FIELD)
    validate_data_and_labels(data, exist_labels_counter, labels_mapping, missing_labels_counter)

    # Print important words for each category (best-effort: failures must not abort training).
    if args.get('findKeywords') == 'true':
        try:
            find_keywords(data, DBOT_TAG_FIELD, text_field, keyword_min_score)
        except Exception as e:
            demisto.debug('find_keywords failed, skipping keyword extraction: {}'.format(e))

    X, y = get_X_and_y_from_data(data, text_field)
    test_index, train_index = get_train_and_test_sets_indices(X, y)
    X_train, X_test = [X[i] for i in train_index], [X[i] for i in test_index]
    y_train, y_test = [y[i] for i in train_index], [y[i] for i in test_index]

    model = demisto_ml.train_text_classifier(X_train, y_train)
    ft_test_predictions = demisto_ml.predict(model, X_test)
    y_pred = [{y_tuple[0]: y_tuple[1]} for y_tuple in ft_test_predictions]
    if return_predictions_on_test_set:
        return_file_result_with_predictions_on_test_set(data, original_text_fields, test_index, text_field, y_test,
                                                        y_pred)

    if 'maxBelowThreshold' in args:
        target_recall = 1 - float(args['maxBelowThreshold'])
    else:
        target_recall = 0
    threshold_metrics_entry, per_class_entry = get_ml_model_evaluation(y_test, y_pred, target_accuracy,
                                                                       target_recall, detailed=True)
    demisto.results(per_class_entry)

    # Show results for the threshold found - last result so it will appear first.
    confusion_matrix = output_model_evaluation(model_name=model_name, y_test=y_test, y_pred=y_pred,
                                               res=threshold_metrics_entry,
                                               context_field='DBotPhishingClassifier')
    if store_model:
        y_test_pred = [y_tuple[0] for y_tuple in ft_test_predictions]
        y_test_pred_prob = [y_tuple[1] for y_tuple in ft_test_predictions]
        threshold = float(threshold_metrics_entry['Contents']['threshold'])
        # The stored model is retrained on all samples (X, y), not only the train split.
        store_model_in_demisto(model_name, model_override, X, y, confusion_matrix, threshold,
                               y_test_true=y_test, y_test_pred=y_test_pred, y_test_pred_prob=y_test_pred_prob)
        demisto.results("Done training on {} samples model stored successfully".format(len(y)))
    else:
        demisto.results('Skip storing model')