Example #1
def augment(features, source_path, input_nbest_path, output_nbest_path):
    ''' Function to augment the n-best list with feature functions
     :param features: List of feature function objects
     :param source_path: Path to the original source sentences (may be required by the feature functions)
     :param input_nbest_path: Path to the n-best file
     :param output_nbest_path: Path to the output n-best file
    '''
    # Initialize NBestList objects
    logger.info('Initializing Nbest lists')
    input_nbest = NBestList(input_nbest_path, mode='r')
    output_nbest = NBestList(output_nbest_path, mode='w')

    # Load the source sentences
    logger.info('Loading source sentences')
    src_sents = codecs.open(source_path, mode='r', encoding='UTF-8')

    # For each item in the n-best list, append the feature scores
    sent_count = 0
    for group, src_sent in zip(input_nbest, src_sents):
        candidate_count = 0
        for item in group:
            for feature in features:
                item.append_feature(
                    feature.name,
                    feature.get_score(src_sent, item.hyp,
                                      (sent_count, candidate_count)))
            output_nbest.write(item)
            candidate_count += 1
        sent_count += 1
        if sent_count % 100 == 0:
            logger.info('Augmented ' + L.b_yellow(str(sent_count)) +
                        ' sentences.')
    output_nbest.close()
    src_sents.close()
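
The snippet above only fixes the interface it expects from each element of `features`: a `name` attribute and a `get_score(source_sentence, hypothesis, (sentence_index, candidate_index))` method. A minimal sketch of such a feature object is shown below; the class name and the length-ratio score are hypothetical, only the interface is taken from `augment()`:

# Hypothetical feature-function object for augment(); only the interface
# (.name and .get_score()) is taken from the code above.
class LengthRatioFeature:
    def __init__(self):
        self.name = 'LengthRatio'

    def get_score(self, source_sentence, hypothesis, indices):
        # indices is the (sent_count, candidate_count) pair passed by augment();
        # this toy feature ignores it.
        src_len = len(source_sentence.split())
        hyp_len = len(hypothesis.split())
        return hyp_len / float(src_len) if src_len else 0.0

# augment([LengthRatioFeature()], 'dev.src', 'dev.nbest', 'dev.augmented.nbest')
# (paths are placeholders; NBestList and logger come from the surrounding project)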
Example #2
File: m2.py  Project: shamilcm/crosentgec-1
def m2_extractor(nbest_path, m2_ref, stats_file, features_file):
    nbest = NBestList(nbest_path, mode='r')
    group_count = -1
    source_sentences, gold_edits = m2_load_annotation(m2_ref, 'all')


    # M2Scorer Parameters
    max_unchanged_words = 2
    beta = 0.5
    ignore_whitespace_casing = False
    verbose = False
    very_verbose = False

    with open(features_file, 'w') as ffeat, open(stats_file, 'w') as fstats:
        for group, source_sentence, golds_set in zip(nbest, source_sentences, gold_edits):
            group_count += 1
            candidate_count = -1
            candidates = list(group)
            for candidate in candidates:
                candidate_count += 1
                feature_score_dict = feature_extractor(candidate.features)
                # write to features file
                p, r, f, stats = m2_levenshtein.batch_multi_pre_rec_f1(
                    [candidate.hyp], [source_sentence], [golds_set],
                    max_unchanged_words, beta, ignore_whitespace_casing,
                    verbose, very_verbose, stats=True)
                if candidate_count == 0:
                    # header for each group
                    feature_list = ['{}_{}'.format(feature_name, idx)
                                    for feature_name, feature_values in feature_score_dict.items()
                                    for idx, feature_value in enumerate(feature_values)]
                    ffeat.write("FEATURES_TXT_BEGIN_0 {} {} {} {}\n".format(
                        group_count, len(candidates), len(feature_list),
                        ' '.join(feature_list)))
                    num_stats = 4
                    fstats.write("SCORES_TXT_BEGIN_0 {} {} {} M2Scorer\n".format(group_count, len(candidates), num_stats))
                # write each line to features/stats
                ffeat.write(' '.join([' '.join([str(val) for val in feature_values])
                                      for feature_values in feature_score_dict.values()]) + '\n')
                fstats.write('{} {} {} {} \n'.format(stats['num_correct'], stats['num_proposed'], stats['num_gold'], stats['num_src_tokens']))
            # footer for each group
            ffeat.write('FEATURES_TXT_END_0\n')
            fstats.write('SCORES_TXT_END_0\n')
            # logging
            logger.info("processed {} groups".format(group_count))
Example #3
L.set_logger(os.path.abspath(args.out_dir), 'train_log.txt')
L.print_args(args)

output_nbest_path = args.out_dir + '/augmented.nbest'
shutil.copy(args.input_nbest, output_nbest_path)

with open(args.weights, 'r') as input_weights:
    lines = input_weights.readlines()
    if len(lines) > 1:
        L.warning(
            "Weights file has more than one line. Reading the first line and ignoring the rest."
        )
    weights = np.asarray(lines[0].strip().split(" "), dtype=float)

prefix = os.path.basename(args.input_nbest)
input_aug_nbest = NBestList(output_nbest_path, mode='r')
output_nbest = NBestList(args.out_dir + '/' + prefix + '.reranked.nbest',
                         mode='w')
output_1best = codecs.open(args.out_dir + '/' + prefix + '.reranked.1best',
                           mode='w',
                           encoding='UTF-8')


def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False
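
The fragment stops before the reranking loop itself. As a rough, hypothetical sketch of how the loaded `weights` vector might be combined with a candidate (assuming `item.features` holds the raw, space-separated feature string of an n-best entry, with `is_number()` filtering out feature-name labels), the score could be a simple dot product:

# Hypothetical continuation: score one n-best item with the loaded weights.
# Treating item.features as a raw feature string is an assumption based on
# the snippets above.
def rerank_score(item, weights):
    values = [float(tok) for tok in item.features.split() if is_number(tok)]
    return np.dot(weights, np.asarray(values))

# One would then re-sort each group of input_aug_nbest by rerank_score(),
# write the sorted group to output_nbest and the top hypothesis to output_1best.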