def get_ann_tokens(sent):
    """Remove annotation artifacts"""
    if isinstance(sent, AnnotatedText):
        ann_tokens = AnnotatedTokens(AnnotatedText(sent.get_annotated_text()))
    elif isinstance(sent, str):
        ann_tokens = AnnotatedTokens(AnnotatedText(sent.rstrip('\n')))
    else:
        raise Exception("Incorrect input for normalization!")
    return ann_tokens
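# Usage sketch (not part of the original source; the sample string is hypothetical and
# shown without annotation markup). Illustrates that get_ann_tokens accepts either a
# raw annotated string or an AnnotatedText object and normalizes both to AnnotatedTokens.
def _demo_get_ann_tokens():
    raw_line = "She go to school .\n"
    from_str = get_ann_tokens(raw_line)                               # str input
    from_obj = get_ann_tokens(AnnotatedText(raw_line.rstrip('\n')))   # object input
    assert from_str.get_annotated_text() == from_obj.get_annotated_text()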
def process_line_with_clf(comb_data):
    """Score every annotation in a line with the confidence classifier."""
    clf, scaler, selector, skip_deps, line = comb_data
    ann_tokens = AnnotatedTokens(AnnotatedText(line))
    for ann in ann_tokens.iter_annotations():
        features, labels, ordered_features_names = get_x_and_y([ann.meta],
                                                               skip_deps)
        features_norm = scaler.transform(features)
        features_selected = selector.transform(features_norm)
        pred_probs = get_clf_pred_probs(clf, features_selected)
        score = pred_probs[0][1]
        ann.meta['clf_score'] = score
    return ann_tokens.get_annotated_text()
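# Usage sketch (assumptions: a fitted clf/scaler/selector are available, e.g. loaded
# with joblib; the function and variable names below are illustrative only). Shows how
# the comb_data tuples consumed by process_line_with_clf are typically packed and
# scored in parallel.
def _score_lines_with_clf(lines, clf, scaler, selector, skip_deps=False, n_procs=4):
    from multiprocessing import Pool
    comb_data = [(clf, scaler, selector, skip_deps, line) for line in lines]
    with Pool(processes=n_procs) as pool:
        return pool.map(process_line_with_clf, comb_data)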
def preprocess_batch(batch, error_types, system_type):
    """Collect (sentence, annotation, position) records for the target error types."""
    records = []
    ann_sents_dict = {}
    for i, sent in enumerate(batch):
        ann_sent = AnnotatedTokens(AnnotatedText(sent))
        ann_sents_dict[i] = ann_sent
        for ann in ann_sent.iter_annotations():
            ann.meta["system_type"] = system_type
            et = get_normalized_error_type(ann)
            if error_types is not None and et not in error_types:
                continue
            records.append([ann_sent, ann, (i, ann.start, ann.end)])
    return records, ann_sents_dict
def wrap_get_lm_scores(sent):
    """Attach a KenLM-based confidence score to every annotation in the sentence."""
    ann_sent = ""
    if not sent:
        return sent
    # Retry until the KenLM service returns scores for the sentence
    while not ann_sent:
        try:
            ann_tokens = AnnotatedTokens(AnnotatedText(sent))
            for ann in ann_tokens.iter_annotations():
                scores = get_kenlm_scores(ann_tokens, [ann], add_original=True)
                confidence = scores[0] - scores[1]
                ann.meta['confidence'] = confidence
            ann_sent = ann_tokens.get_annotated_text()
        except Exception as e:
            print(
                f"Something wrong with KenLM. Exception {e}. Sentence {sent}")
    return ann_sent
def evaluate_with_m2(gold_annotations, output_annotations, tmp_filename):
    """Convert gold annotations to M2 format and score the system output with m2scorer."""
    assert len(gold_annotations) == len(output_annotations)
    gold_ann_tokens = [AnnotatedTokens(AnnotatedText(anno_text)) for anno_text in gold_annotations]
    gold_m2_annotations = []
    for ann_tokens in gold_ann_tokens:
        try:
            converted = MultiAnnotatedSentence.from_annotated_tokens(ann_tokens).to_m2_str() + '\n'
            gold_m2_annotations.append(converted)
        except Exception:
            # Conversion failed: drop annotations without valid suggestions and retry
            for ann in ann_tokens.iter_annotations():
                if not ann.suggestions or str(ann.suggestions[0]) == "NO_SUGGESTIONS":
                    ann_tokens.remove(ann)
            new_converted = MultiAnnotatedSentence.from_annotated_tokens(ann_tokens).to_m2_str() + '\n'
            gold_m2_annotations.append(new_converted)

    output_corrected_texts = [AnnotatedText(anno_text).get_corrected_text() for anno_text in output_annotations]
    # Write as text files

    gold_file_processed = f"g_{os.path.basename(tmp_filename)}"
    sub_file_processed = f"o_{os.path.basename(tmp_filename)}"
    write_lines(gold_file_processed, gold_m2_annotations)
    write_lines(sub_file_processed, output_corrected_texts)
    # Run m2scorer (OFFICIAL VERSION 3.2, http://www.comp.nus.edu.sg/~nlp/conll14st.html)
    system(f'./m2scorer/m2scorer {sub_file_processed} {gold_file_processed}')
    remove_file(sub_file_processed)
    remove_file(gold_file_processed)
def combine_systems_output(sentence_list, strategy, discard_priorities):
    """Combine the outputs of all systems into a single annotated sentence."""
    all_anns = []
    tokens_list = []
    n_options = len(sentence_list)
    for sent in sentence_list:
        ann_tokens = get_ann_tokens(sent)
        tokens_list.append(ann_tokens._tokens)
        all_anns.extend([x for x in ann_tokens.iter_annotations()])
    paired_anns = split_annotations_on_disjoint_pairs(all_anns)
    tokens = tokens_list[0]
    ann_tokens = AnnotatedTokens(tokens)
    for i, comb_anns_list in enumerate(paired_anns):
        ann = resolve_annotation_conflicts(ann_tokens, comb_anns_list,
                                           strategy, n_options,
                                           discard_priorities)

        if ann and ann.suggestions and str(
                ann.suggestions[0]) != "NO_SUGGESTIONS":
            ann_tokens.annotate(ann.start, ann.end, ann.suggestions, ann.meta)
    return ann_tokens
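# Usage sketch (the strategy and discard_priorities values below are assumptions; the
# accepted options are defined by resolve_annotation_conflicts elsewhere in the
# codebase). Combines the annotated outputs that several systems produced for the same
# original sentence into one annotated string.
def _demo_combine_systems(outputs_per_system):
    combined = combine_systems_output(outputs_per_system,
                                      strategy="majority",      # assumed strategy name
                                      discard_priorities=None)  # assumed default
    return combined.get_annotated_text()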
def run_check_parallel(orig_list, check_type, error_type, n_threads, fn_out):
    if check_type == 'Patterns':
        combined_data = get_combined_data(orig_list, check_type)
    elif check_type == 'OPC-with-filters':
        filters = {"<ErrorTypesFilter(types=None)>": {"types": [error_type]}}
        combined_data = get_combined_data(orig_list,
                                          check_type,
                                          addr='PREPROD',
                                          filters=filters)
    elif check_type == 'OPC-without-filters':
        filters = False
        combined_data = get_combined_data(orig_list,
                                          check_type,
                                          addr='PREPROD',
                                          filters=filters)
    elif check_type == 'UPC5-high-precision':
        combined_data = get_combined_data(orig_list, check_type)
    elif check_type == 'UPC5-high-recall':
        upc_addr = "upc-high-recall-server.phantasm.gnlp.io:8081"
        combined_data = get_combined_data(orig_list,
                                          check_type,
                                          addr=upc_addr,
                                          custom_server=True)
    else:
        raise ValueError('Unknown check_type = %s' % check_type)

    # create helper object to deal with batches
    batcher = Batcher(combined_data, batch_size=n_threads, verbose=True)
    pool = Pool(processes=n_threads)  # process pool for parallel checking
    result_anno = list()
    for batch in batcher.iter_batches():
        result_anno_batch = pool.map(wrapped_check_func, batch)
        result_anno.extend(result_anno_batch)
    pool.close()
    pool.join()
    # Normalize the annotated texts by round-tripping them through AnnotatedTokens
    normalized_result_anno = [
        AnnotatedTokens(AnnotatedText(x)).get_annotated_text()
        for x in result_anno
    ]
    write_lines(fn_out, normalized_result_anno)
    return normalized_result_anno
def get_processed_records(train_file,
                          error_types=[],
                          store_sents=True,
                          skip_dep_features=False):
    """Read an annotated training file and build feature/label arrays for classifier training."""
    all_ann_sents = []
    features, labels = [], []
    error_types_list = []
    features_names = get_full_list_of_features()
    if skip_dep_features:
        dep_features = get_list_of_dependent_features()
        features_names = [x for x in features_names if x not in dep_features]
    # fallback ordering in case no annotation below yields an ordered feature list
    ordered_features_names = features_names
    with open(train_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            sent = line.strip()
            ann_sent = AnnotatedTokens(AnnotatedText(sent))
            for ann in ann_sent.iter_annotations():
                # apply error_type features
                ann_error_type = get_normalized_error_type(ann)
                if error_types and ann_error_type not in error_types:
                    ann_sent.remove(ann)
                    continue
                # check if ann has all features
                absent_features = [
                    x for x in features_names if x not in list(ann.meta.keys())
                ]
                if not absent_features:
                    # process annotation
                    try:
                        tmp_x, label, ordered_features_names = \
                            process_record(ann.meta, features_names)
                        features.append(tmp_x[:])
                        labels.append(label)
                        error_types_list.append(ann_error_type)
                    except Exception as e:
                        print(f"Skipping annotation ({e}): {ann.meta}")
                        continue
                else:
                    ann_sent.remove(ann)
            if store_sents:
                all_ann_sents.append(ann_sent)
    features_names = ordered_features_names
    features = np.array(features)
    labels = np.array(labels)
    return features, labels, features_names, error_types_list, all_ann_sents
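# Training sketch (scikit-learn is an assumption; the original training code is not
# shown in this example). Illustrates how the feature/label arrays returned by
# get_processed_records could yield the clf/scaler/selector triple that
# process_line_with_clf expects.
def _train_confidence_clf(train_file, error_types):
    from sklearn.preprocessing import StandardScaler
    from sklearn.feature_selection import SelectKBest, f_classif
    from sklearn.linear_model import LogisticRegression
    features, labels, names, _, _ = get_processed_records(
        train_file, error_types=error_types, store_sents=False)
    scaler = StandardScaler().fit(features)
    selector = SelectKBest(f_classif, k=min(50, features.shape[1])).fit(
        scaler.transform(features), labels)
    clf = LogisticRegression(max_iter=1000).fit(
        selector.transform(scaler.transform(features)), labels)
    return clf, scaler, selector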
def get_kenlm_scores(ann_tokens, ann_list, add_original=False, use_norm=False):
    """Score all annotations in the list using KenLM"""
    kenlm = ngram_client.KenLMClient()
    sent_list = []
    for ann in ann_list:
        sent = AnnotatedTokens(ann_tokens._tokens)
        sent.annotate(ann.start, ann.end, ann.suggestions)
        sent_list.append(sent.get_corrected_text())
    if add_original:
        # also score the unmodified sentence (used as the baseline)
        sent_list.append(ann_tokens.get_original_text())
    scores = kenlm.ask_ngrams(sent_list)
    if use_norm:
        scores = [
            scores[i] / max(len(sent_list[i].split()), 1)
            for i in range(len(scores))
        ]
    return scores
def main(args):
    clc_csv_reader = ClcCsvReader(fn=args.fn_clc_csv)
    error_types_bank = ErrorTypesBank()
    target_error_types_list = error_types_bank.patterns22_to_clc89(
        args.target_error_type)
    orig_lines = list()  # original texts
    gold_annotations = list()  # gold corrections in Annotated Text string format
    for _, _, _, _, gold_relabeled_anno_text, gold_error_types_list \
            in clc_csv_reader.iter_items(max_item_number=args.max_item_number):
        # We are not interested in text samples that don't contain
        # at least one target error type
        if not is_lists_intersection(gold_error_types_list,
                                     target_error_types_list):
            continue
        ann_tokens = AnnotatedTokens(AnnotatedText(gold_relabeled_anno_text))
        for ann in ann_tokens.iter_annotations():
            if ann.meta['error_type'] not in target_error_types_list:
                ann_tokens.remove(ann)
        gold_annotations_renormalized = ann_tokens.get_annotated_text()
        # Add renormalized texts to the lists
        orig_sent = ann_tokens.get_original_text()
        orig_lines.append(orig_sent)
        gold_annotations.append(gold_annotations_renormalized)
    assert len(orig_lines) == len(gold_annotations)
    print('%d lines in unfiltered outputs.' % len(orig_lines))
    gold_annotations_filtered, orig_lines_filtered = filter_by_nosuggestions_in_gold(
        gold_annotations, orig_lines)
    assert len(gold_annotations_filtered) == len(orig_lines_filtered)
    print('%d lines in filtered by NO_SUGGESTION flag outputs.' %
          len(orig_lines_filtered))
    # Write to files
    fn_out_gold_file = args.fn_clc_csv.replace(
        '.csv', f'_{args.target_error_type}_gold.txt')
    fn_out_orig_file = args.fn_clc_csv.replace(
        '.csv', f'_{args.target_error_type}_orig.txt')
    write_lines(fn=fn_out_gold_file, lines=gold_annotations_filtered)
    write_lines(fn=fn_out_orig_file, lines=orig_lines_filtered)
def get_all_records(train_file, error_types=[], store_sents=True):
    """Collect annotation meta dicts (feature records) from an annotated file."""
    all_sents = read_lines(train_file)
    print("All sents are loaded")
    all_records = []
    all_ann_sents = []
    for i, sent in enumerate(tqdm(all_sents)):
        ann_sent = AnnotatedTokens(AnnotatedText(sent))
        for ann in ann_sent.iter_annotations():
            # apply error_type features
            if error_types:
                ann_error_type = get_normalized_error_type(ann)
                if ann_error_type not in error_types:
                    ann_sent.remove(ann)
                    continue
            # keep the annotation only if its meta dict carries a full feature set
            if len(ann.meta.keys()) > 5:
                all_records.append(ann.meta)
            # otherwise drop the under-featured annotation
            else:
                ann_sent.remove(ann)
        if store_sents:
            all_ann_sents.append(ann_sent)
    return all_records, all_ann_sents
def filter_by_error_type(raw_list, error_type, default_system_type=None,
                         with_meta=False):
    """Keep only annotations of the target error type and report kept/dropped counts."""
    ebank = ErrorTypesBank()
    output = []
    cnt_target_errors = 0
    cnt_other_errors = 0
    if default_system_type is not None:
        system_type = default_system_type
    for sent in raw_list:
        ann_tokens = AnnotatedTokens(AnnotatedText(sent))
        for ann in ann_tokens.iter_annotations():
            if 'error_type' in ann.meta:
                ann_error_type = ann.meta['error_type']
            elif 'pname' in ann.meta:
                ann_error_type = ann.meta['pname']
            else:
                print(f'Broken annotation {ann}')
                ann_error_type = "OtherError"
            # set system type
            if default_system_type is None:
                system_type = ann.meta['system_type']
            if system_type.startswith('OPC'):
                norm_error_type = ebank.opc_to_patterns22(ann_error_type)
            elif system_type.startswith('UPC'):
                norm_error_type = ebank.upc5_to_patterns22(ann_error_type)
            elif system_type == 'Patterns':
                norm_error_type = \
                    ebank.pname_to_patterns22(ann_error_type)
            elif system_type == "CLC":
                norm_error_type = ebank.clc89_to_patterns22(ann_error_type)
            else:
                print(f'Unknown system {system_type}')
                norm_error_type = "OtherError"
            if norm_error_type != error_type:
                ann_tokens.remove(ann)
                cnt_other_errors += 1
            else:
                cnt_target_errors += 1
        output.append(ann_tokens.get_annotated_text(with_meta=with_meta))
    print(f'Stats: N_target_errors = {cnt_target_errors}, '
          f'N_other_errors = {cnt_other_errors}')
    return output, cnt_target_errors
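# Usage sketch (the error_type value is an assumption; "Patterns" is one of the system
# types handled above). Keeps only annotations whose normalized error type matches the
# target and returns the filtered annotated texts plus the number of kept annotations.
def _demo_filter_by_error_type(annotated_lines):
    filtered, n_target = filter_by_error_type(annotated_lines,
                                              error_type="Articles",  # assumed type name
                                              default_system_type="Patterns",
                                              with_meta=True)
    return filtered, n_target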
def main(args):
    all_files = [x for x in os.listdir(args.input_dir) if "tmp" not in x]
    for fname in all_files:
        print(f"Start evaluation {args.system_type} on {fname}")
        if fname.endswith(".txt"):
            fp_ratio = True
        elif fname.endswith(".m2"):
            fp_ratio = False
        else:
            continue
        input_file = os.path.join(args.input_dir, fname)
        if fp_ratio:
            sentences = read_lines(input_file)
        else:
            sentences = get_lines_from_m2_file(input_file)

        # run through system
        if args.system_type == "OPC":
            s_types = ["OPC", "OPC-filtered"]
        elif args.system_type == "UPC":
            s_types = ["UPC"]
        else:
            raise Exception("Unknown system type")
        for system_type in s_types:
            print(f"{system_type} is evaluating")
            combined = [(x, system_type) for x in sentences]
            with ThreadPoolExecutor(args.n_threads) as pool:
                system_out = list(
                    tqdm(pool.map(wrap_check, combined), total=len(combined)))
            system_out = [x.get_annotated_text() for x in system_out]
            print(f"{system_type} system response was got")

            # run system through confidence scorer
            if system_type.endswith("filtered"):
                scorer_list = [None]
            else:
                scorer_list = [None, "LM", "CLF"]
                # scorer_list = [None, "CLF"]
            for scorer in scorer_list:
                print(f"Current scorer is {scorer}")
                if scorer == "CLF":
                    combined = [(x, args.server_path) for x in system_out]
                    with ThreadPoolExecutor(args.n_threads) as pool:
                        scorer_out = list(
                            tqdm(pool.map(wrap_confidence_scorer, combined),
                                 total=len(combined)))
                    # thresholds = [0.1, 0.2, 0.25, 0.3, 0.5]
                    thresholds = [
                        0.1, 0.2, 0.25, 0.3, 0.35, 0.36, 0.38, 0.4, 0.45, 0.5
                    ]
                elif scorer == "LM":
                    with ThreadPoolExecutor(args.n_threads) as pool:
                        scorer_out = list(
                            tqdm(pool.map(wrap_get_lm_scores, system_out),
                                 total=len(combined)))
                    thresholds = [0]
                else:
                    scorer_out = system_out
                    thresholds = [None]
                print("Scores were got")

                # apply thresholds
                if args.error_types is not None:
                    error_types = args.error_types.split()
                else:
                    error_types = None
                for t in thresholds:
                    print(f"The current threshold is {t}")
                    t_out = []
                    for sent in scorer_out:
                        ann_sent = AnnotatedTokens(AnnotatedText(sent))
                        for ann in ann_sent.iter_annotations():
                            ann.meta['system_type'] = system_type
                            et = get_normalized_error_type(ann)
                            if error_types is not None and et not in error_types:
                                ann_sent.remove(ann)
                                continue
                            score = float(ann.meta.get('confidence', 1))
                            if t is not None and score < t:
                                ann_sent.remove(ann)
                        t_out.append(ann_sent.get_annotated_text())
                    if fp_ratio:
                        cnt_errors = sum([
                            len(AnnotatedText(x).get_annotations())
                            for x in t_out
                        ])
                        print(
                            f"\nThe number of errors equals {cnt_errors}. "
                            f"FP rate {round(100 * cnt_errors / len(t_out), 2)}%")
                    else:
                        print(f"\nThreshold level is {t}")
                        tmp_filename = input_file.replace(
                            ".m2",
                            f"_{system_type}_{scorer}_above_{t}_tmp.txt")
                        evaluate_from_m2_file(input_file, t_out, tmp_filename)