Example #1
def eval(epoch):
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    for raw_src, src, src_len, raw_tgt, tgt, tgt_len in validloader:
        if len(opt.gpus) > 1:
            samples, targets, alignment = model.module.sample(
                src, src_len, tgt, tgt_len)
        else:
            samples, targets, alignment = model.sample(src, src_len, tgt,
                                                       tgt_len)

        candidate += [tgt_vocab.convertToLabels(s, dict.EOS) for s in samples]
        source += raw_src
        reference += raw_tgt
        alignments += [align for align in alignment]

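    # replace UNK tokens in the candidates with the aligned source word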
    if opt.unk:
        cands = []
        for s, c, align in zip(source, candidate, alignments):
            #cand = [word if word != dict.UNK_WORD or idx >= len(s) else s[idx] for word, idx in zip(c, align)]
            cand = []
            for word, idx in zip(c, align):
                if word == dict.UNK_WORD and idx < len(s):
                    try:
                        cand.append(s[idx])
                    except IndexError:
                        cand.append(word)
                        print("%d %d" % (len(s), idx))
                else:
                    cand.append(word)
            cands.append(cand)
        candidate = cands

    score = {}

    if 'bleu' in config.metric:
        result = utils.eval_bleu(reference, candidate, log_path, config)
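        # parse the numeric score from eval_bleu's result string (third token, trailing character stripped)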
        score['bleu'] = float(result.split()[2][:-1])
        logging(result)

    if 'rouge' in config.metric:
        result = utils.eval_rouge(reference, candidate, log_path)
        try:
            score['rouge'] = result['F_measure'][0]
            logging("F_measure: %s Recall: %s Precision: %s\n" %
                    (result['F_measure'], result['recall'], result['precision']))
        except Exception:
            logging("Failed to compute rouge score.\n")
            score['rouge'] = 0.0

    return score
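The opt.unk branch above is pointer-style UNK replacement: every generated UNK token is swapped for the source token that the attention alignment points at. A minimal, self-contained sketch of that post-processing step, with an assumed value for dict.UNK_WORD:

# Minimal sketch of the UNK replacement; '<unk>' is an assumed value for dict.UNK_WORD.
UNK_WORD = '<unk>'

def replace_unk(source_tokens, cand_tokens, alignment):
    out = []
    for word, idx in zip(cand_tokens, alignment):
        # copy the aligned source token whenever the model emitted UNK
        if word == UNK_WORD and 0 <= idx < len(source_tokens):
            out.append(source_tokens[idx])
        else:
            out.append(word)
    return out

# replace_unk(['Pierre', 'Vinken', 'retired'], ['<unk>', 'retired'], [1, 2])
# -> ['Vinken', 'retired']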
Example #2
def eval(epoch):
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    count, total_count = 0, len(validset)

    for batch in validloader:
        raw_src, src, src_len, raw_tgt, tgt, tgt_len = \
            batch['raw_src'], batch['src'], batch['src_len'], batch['raw_tgt'], batch['tgt'], batch['tgt_len']

        if 'num_oovs' in batch:
            num_oovs = batch['num_oovs']
            oovs = batch['oovs']
        else:
            num_oovs = 0
            oovs = None

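        # greedy sampling when beam_size is 1, otherwise beam search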
        if config.beam_size == 1:
            samples, alignment = model.sample(src, src_len, num_oovs=num_oovs)
        else:
            samples, alignment = model.beam_sample(src,
                                                   src_len,
                                                   beam_size=config.beam_size)

        if oovs is not None:
            candidate += [
                tgt_vocab.convertToLabels(s, dict.EOS, oovs=oov)
                for s, oov in zip(samples, oovs)
            ]
        else:
            candidate += [
                tgt_vocab.convertToLabels(s, dict.EOS) for s in samples
            ]
        source += raw_src
        reference += raw_tgt
        alignments += [align for align in alignment]

        count += len(raw_src)
        utils.progress_bar(count, total_count)

    if opt.unk:
        # replace UNK tokens with the aligned source word
        cands = []
        for s, c, align in zip(source, candidate, alignments):
            cand = []
            for word, idx in zip(c, align):
                if word == dict.UNK_WORD and idx < len(s):
                    try:
                        cand.append(s[idx])
                    except IndexError:
                        cand.append(word)
                        print("%d %d" % (len(s), idx))
                else:
                    cand.append(word)

            cands.append(cand)
        candidate = cands

    score = {}

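    # optionally split outputs and references into characters before scoring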
    if hasattr(config, 'convert'):
        candidate = utils.convert_to_char(candidate)
        reference = utils.convert_to_char(reference)

    if 'bleu' in config.metric:
        result = utils.eval_bleu(reference, candidate, log_path, config)
        score['bleu'] = float(result.split()[2][:-1])
        logging(result)

    if 'rouge' in config.metric:
        result = utils.eval_rouge(reference, candidate, log_path)
        try:
            score['rouge'] = result['F_measure'][0]
            logging("F_measure: %s Recall: %s Precision: %s\n" %
                    (result['F_measure'], result['recall'], result['precision']))
            #optim.updateLearningRate(score=score['rouge'], epoch=epoch)
        except Exception:
            logging("Failed to compute rouge score.\n")
            score['rouge'] = 0.0

    if 'multi_rouge' in config.metric:
        result = utils.eval_multi_rouge(reference, candidate, log_path)
        try:
            score['multi_rouge'] = result['F_measure'][0]
            logging("F_measure: %s Recall: %s Precision: %s\n" %
                    (result['F_measure'], result['recall'], result['precision']))
        except Exception:
            logging("Failed to compute rouge score.\n")
            score['multi_rouge'] = 0.0

    if 'SARI' in config.metric:
        result = utils.eval_SARI(source, reference, candidate, log_path,
                                 config)
        logging("SARI score is: %.2f\n" % result)
        score['SARI'] = result

    return score
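Example #2 decodes with per-example OOV lists (the oovs passed to convertToLabels), which suggests a copy mechanism with an extended vocabulary. A rough sketch of that convention, using a hypothetical ids_to_words helper rather than the repo's actual convertToLabels; the id layout (fixed vocabulary first, then OOV slots) is an assumption:

# Hypothetical helper illustrating the assumed extended-vocabulary convention:
# ids below the vocabulary size come from the target vocabulary, larger ids
# index the per-example OOV list produced by the copy mechanism.
def ids_to_words(ids, itos, oovs, eos_id):
    words = []
    for i in ids:
        if i == eos_id:
            break
        if i < len(itos):
            words.append(itos[i])
        else:
            words.append(oovs[i - len(itos)])
    return words

# ids_to_words([5, 6, 3], itos=['<pad>', '<unk>', '<s>', '</s>', 'the', 'cat'],
#              oovs=['Vinken'], eos_id=3)
# -> ['cat', 'Vinken']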
Example #3
def eval(epoch):
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    count, total_count = 0, len(testset)

    for batch in testloader:
        raw_src, src, src_len, raw_tgt, tgt, tgt_len = \
            batch['raw_src'], batch['src'], batch['src_len'], batch['raw_tgt'], batch['tgt'], batch['tgt_len']

        if 'num_oovs' in batch:
            num_oovs = batch['num_oovs']
            oovs = batch['oovs']
        else:
            num_oovs = 0
            oovs = None

        if beam_size == 1:
            samples, alignment = model.sample(src, src_len, num_oovs=num_oovs)
        else:
            samples, alignment = model.beam_sample(src,
                                                   src_len,
                                                   beam_size=beam_size,
                                                   num_oovs=num_oovs)

        candidate += [tgt_vocab.convertToLabels(s, dict.EOS) for s in samples]
        source += raw_src
        reference += raw_tgt
        alignments += [align for align in alignment]

        count += len(raw_src)
        utils.progress_bar(count, total_count)

    # replace UNK tokens with the aligned source word (applied unconditionally at test time)
    cands = []
    for s, c, align in zip(source, candidate, alignments):
        cand = []
        for word, idx in zip(c, align):
            if word == dict.UNK_WORD and idx < len(s):
                try:
                    cand.append(s[idx])
                except IndexError:
                    cand.append(word)
                    print("%d %d" % (len(s), idx))
            else:
                cand.append(word)
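        # optionally drop both tokens of any already-seen bigram to reduce repetition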
        if opt.reduce:
            phrase_set = set()
            mask = [1 for _ in range(len(cand))]
            for i in range(1, len(cand)):
                phrase = cand[i - 1] + " " + cand[i]
                if phrase in phrase_set:
                    mask[i - 1] = 0
                    mask[i] = 0
                else:
                    phrase_set.add(phrase)
            cand = [word for word, m in zip(cand, mask) if m == 1]
        cands.append(cand)
    candidate = cands

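    # optionally bucket outputs by source length and report ROUGE per bucket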
    if opt.group:
        lengths = [
            90, 95, 100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 150
        ]
        group_cand = collections.OrderedDict()
        group_ref = collections.OrderedDict()
        for length in lengths:
            group_cand[length] = []
            group_ref[length] = []
        total_length = []
        for s, c, r in zip(source, candidate, reference):
            length = len(s)
            total_length.append(length)
            for l in lengths:
                if length <= l:
                    group_ref[l].append(r)
                    group_cand[l].append(c)
                    break
        print("min length %d, max length %d" %
              (min(total_length), max(total_length)))
        if 'rouge' in config.metric:
            for l in lengths:
                print("length %d, count %d" % (l, len(group_cand[l])))
                result = utils.eval_rouge(group_ref[l], group_cand[l],
                                          log_path)
                try:
                    logging(
                        "length: %d F_measure: %s Recall: %s Precision: %s\n\n"
                        % (l, result['F_measure'], result['recall'],
                           result['precision']))
                except Exception:
                    logging("Failed to compute rouge score.\n")

    score = {}

    if 'bleu' in config.metric:
        result = utils.eval_bleu(reference, candidate, log_path, config)
        score['bleu'] = float(result.split()[2][:-1])
        logging(result)

    if 'rouge' in config.metric:
        result = utils.eval_rouge(reference, candidate, log_path)
        try:
            score['rouge'] = result['F_measure'][0]
            logging("F_measure: %s Recall: %s Precision: %s\n" %
                    (result['F_measure'], result['recall'], result['precision']))
        except Exception:
            logging("Failed to compute rouge score.\n")
            score['rouge'] = 0.0

    if 'multi_rouge' in config.metric:
        result = utils.eval_multi_rouge(reference, candidate, log_path)
        try:
            score['multi_rouge'] = result['F_measure'][0]
            logging("F_measure: %s Recall: %s Precision: %s\n" %
                    (result['F_measure'], result['recall'], result['precision']))
        except Exception:
            logging("Failed to compute rouge score.\n")
            score['multi_rouge'] = 0.0

    if 'SARI' in config.metric:
        result = utils.eval_SARI(source, reference, candidate, log_path,
                                 config)
        logging("SARI score is: %.2f\n" % result)
        score['SARI'] = result

    return score
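The opt.reduce step above is a cheap repetition filter: whenever a bigram repeats, both of its tokens are dropped. A standalone sketch of the same filter:

def drop_repeated_bigrams(tokens):
    # mark both tokens of any bigram that was seen before
    seen = set()
    keep = [True] * len(tokens)
    for i in range(1, len(tokens)):
        bigram = (tokens[i - 1], tokens[i])
        if bigram in seen:
            keep[i - 1] = keep[i] = False
        else:
            seen.add(bigram)
    return [t for t, k in zip(tokens, keep) if k]

# drop_repeated_bigrams("the cat sat on the cat sat here".split())
# -> ['the', 'cat', 'sat', 'on', 'here']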