Example #1
    def __init__(self, model, source_file, target_file, test_source_file, test_target_file,
                 raw_source_file,
                 raw_target_file, num_sentences=400,
                 batch_translate=True):
        self.model = model
        self.source_file = source_file
        self.target_file = target_file
        self.loader = LanguagePairLoader("de", "en", source_file, target_file)
        self.test_loader = LanguagePairLoader("de", "en", test_source_file, test_target_file)
        self.extractor = DomainSpecificExtractor(source_file=raw_source_file, train_source_file=hp.source_file,
                                                 train_vocab_file="train_vocab.pkl")
        self.target_extractor = DomainSpecificExtractor(source_file=raw_target_file, train_source_file=hp.target_file,
                                                        train_vocab_file="train_vocab_en.pkl")
        self.scorer = Scorer()
        self.scores = {}
        self.num_sentences = num_sentences
        self.batch_translate = batch_translate
        self.evaluate_every = 10

        self.metric_bleu_scores = {}
        self.metric_gleu_scores = {}
        self.metric_precisions = {}
        self.metric_recalls = {}

        # Shared styling for the per-metric plots
        plt.style.use('seaborn-darkgrid')
        self.palette = sns.color_palette()
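For context, a minimal sketch of how this constructor might be invoked. The enclosing class name, MetricExperiment, and most file paths are assumptions (only "data/khresmoi.tok.de" appears elsewhere in these examples); load_model() is the loader used in Example #4:

# Hypothetical usage; the class name and most paths are assumptions.
model = load_model()  # trained Seq2SeqModel, as in Example #4
experiment = MetricExperiment(
    model,
    source_file="data/khresmoi.bpe.de",
    target_file="data/khresmoi.bpe.en",
    test_source_file="data/khresmoi.test.bpe.de",
    test_target_file="data/khresmoi.test.bpe.en",
    raw_source_file="data/khresmoi.tok.de",
    raw_target_file="data/khresmoi.tok.en",
    num_sentences=400)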
Example #2
    def run(self):
        loader = LanguagePairLoader("de", "en", self.source_file, self.target_file)
        _, _, pairs = loader.load()

        pairs = pairs[:self.num_sentences]
        # Translate sources
        sources, targets, translations = [p[0] for p in pairs], [p[1] for p in pairs], []

        extractor = DomainSpecificExtractor(source_file="data/khresmoi.tok.de",
                                            train_source_file=hp.source_file,
                                            train_vocab_file="train_vocab.pkl")
        keyphrases = extractor.extract_keyphrases(n_results=100)
        print(keyphrases)

        for i, pair in enumerate(pairs):
            if i % 10 == 0:
                print("Translated {} of {}".format(i, len(pairs)))
            translation, attn, _ = self.model.translate(pair[0], beam_size=1)
            translations.append(" ".join(translation[:-1]))
            scores = self.scorer.compute_scores(pair[0], " ".join(translation[:-1]), attn, keyphrases)

            for metric in scores:
                if metric == "coverage_penalty" and scores[metric] > 80:
                    continue
                if metric == "keyphrase_score" and scores[metric] == 0:
                    continue

                if metric not in self.metric_to_gleu:
                    self.metric_to_gleu[metric] = {}
                if scores[metric] not in self.metric_to_gleu[metric]:
                    self.metric_to_gleu[metric][scores[metric]] = []
                gleu = compute_cter(pair[1], " ".join(translation[:-1]))  # despite the name, this is a CTER score
                self.all_gleu_scores.append(gleu)
                self.metric_to_gleu[metric][scores[metric]].append(gleu)
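The nested metric_to_gleu mapping built above has the shape metric name → observed metric value → list of quality scores. A minimal sketch, assuming the mapping has been populated as in run(), of collapsing each value bucket to its mean for correlation plots:

def average_by_metric_value(metric_to_gleu):
    # Collapse each {metric value: [scores, ...]} bucket to its mean score.
    return {metric: {value: sum(scores) / len(scores)
                     for value, scores in by_value.items()}
            for metric, by_value in metric_to_gleu.items()}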
Example #3
    def extract_keyphrases(self, n_results=20):
        # Load the cached training vocabulary, or build and cache it on first use.
        if os.path.isfile(self.train_vocab_file):
            with open(self.train_vocab_file, "rb") as f:
                train_vocab = pickle.load(f)
        else:
            train_vocab = Counter()
            # Only the source side is needed, so the source file is passed twice.
            train_loader = LanguagePairLoader("de", "en", self.train_source_file, self.train_source_file)
            _, _, train_pairs = train_loader.load()
            for source, _ in train_pairs:
                for word in source.replace("@@ ", "").split(" "):
                    train_vocab[word] += 1
            with open(self.train_vocab_file, "wb") as f:
                pickle.dump(train_vocab, f)

        # Again, only the source side matters here.
        loader = LanguagePairLoader("de", "en", self.source_file, self.source_file)
        in_lang, _, pairs = loader.load()

        # Collect words that are rare in the training corpus but occur in the new
        # domain, together with the number of in-domain sentences containing them.
        domain_words = []
        for word in in_lang.word2count:
            if train_vocab[word] < self.frequency_threshold and in_lang.word2count[word] > 0:
                freq = 0
                for source, _ in pairs:
                    if word.lower() in source.lower():
                        freq += 1
                domain_words.append((word, freq))

        domain_words = sorted(domain_words, key=lambda w: in_lang.word2count[w[0]], reverse=True)
        return domain_words[:n_results]
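A short usage sketch of extract_keyphrases; the in-domain path is hypothetical, while hp.source_file and the cache file name follow the surrounding examples:

extractor = DomainSpecificExtractor(source_file="data/medical.tok.de",  # hypothetical path
                                    train_source_file=hp.source_file,
                                    train_vocab_file="train_vocab.pkl")
# Up to 20 (word, in-domain sentence frequency) tuples, ordered by the word's
# in-domain corpus count, e.g. [("Aorta", 17), ("Stenose", 9)] (illustrative).
keyphrases = extractor.extract_keyphrases(n_results=20)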
Example #4
    seq2seq_model.decoder.load_state_dict(decoder_state)


def compute_bleu(targets, translations):
    import nltk

    # Undo BPE segmentation ("@@ " marks subword joins), then tokenize on spaces.
    # corpus_bleu expects a list of reference lists per hypothesis.
    references = [[target.replace("@@ ", "").split(" ")] for target in targets]
    hypotheses = [t.replace("@@ ", "").split(" ") for t in translations]

    return nltk.translate.bleu_score.corpus_bleu(references, hypotheses)
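# Illustrative sanity check (hypothetical toy data): "@@ " marks BPE subword
# joins and is undone before scoring, so an identical corpus scores 1.0.
# print(compute_bleu(["the pat@@ ient was dis@@ charged"],
#                    ["the pat@@ ient was dis@@ charged"]))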


model = load_model()
loader = LanguagePairLoader("de", "en", "data/auto.bpe.de", "data/auto.bpe.en")
_, _, pairs = loader.load()

sources, targets = [p[0] for p in pairs], [p[1] for p in pairs]
translations = []
for source in sources:
    translation, _, _ = model.translate(source)
    translations.append(" ".join(translation[:-1]))

print("BLEU {}".format(compute_bleu(targets, translations)))

encoder_optimizer_state, decoder_optimizer_state = retrain_iters(
    model,
    pairs, [],
    batch_size=20,
    encoder_optimizer_state=None,
Example #5
class AveragedMetricExperiment:
    def __init__(self, model, source_file, target_file, raw_source_file, raw_target_file, num_sentences=400):
        self.model = model
        self.source_file = source_file
        self.target_file = target_file
        self.loader = LanguagePairLoader("de", "en", source_file, target_file)
        self.extractor = DomainSpecificExtractor(source_file=raw_source_file, train_source_file=hp.source_file,
                                                 train_vocab_file="train_vocab.pkl")
        self.target_extractor = DomainSpecificExtractor(source_file=raw_target_file, train_source_file=hp.target_file,
                                                        train_vocab_file="train_vocab_en.pkl")
        self.scorer = Scorer()
        self.scores = {}
        self.num_sentences = num_sentences

        self.metric_bleu_scores = {}
        self.metric_gleu_scores = {}
        self.metric_precisions = {}
        self.metric_recalls = {}
        self.cer = {}

        # Shared styling for the per-metric plots
        plt.style.use('seaborn-darkgrid')
        self.palette = sns.color_palette()

    def save_data(self):
        prefix = "averaged_"
        pickle.dump(self.metric_bleu_scores, open(prefix + "metric_bleu_scores.pkl", "wb"))
        pickle.dump(self.metric_gleu_scores, open(prefix + "metric_gleu_scores.pkl", "wb"))
        pickle.dump(self.metric_precisions, open(prefix + "metric_precisions.pkl", "wb"))
        pickle.dump(self.metric_recalls, open(prefix + "metric_recalls.pkl", "wb"))
        pickle.dump(self.cer, open(prefix + "metric_cer.pkl", "wb"))
        print("Saved all scores")

    def run(self):
        _, _, pairs = self.loader.load()
        random.seed(2018)
        random.shuffle(pairs)

        pairs = pairs[:self.num_sentences]

        sources, targets, translations = [p[0] for p in pairs], [p[1] for p in pairs], []

        keyphrases = self.extractor.extract_keyphrases(n_results=100)
        print(keyphrases)
        target_keyphrases = self.target_extractor.extract_keyphrases(n_results=100)
        print(target_keyphrases)

        for i, pair in enumerate(pairs):
            if i % 10 == 0:
                print("Translated {} of {}".format(i, len(pairs)))
            translation, attn, _ = self.model.translate(pair[0])
            translations.append(" ".join(translation[:-1]))

            metrics_scores = self.scorer.compute_scores(pair[0], " ".join(translation[:-1]), attn, keyphrases)
            for metric in metrics_scores:
                if metric not in self.scores:
                    self.scores[metric] = []
                self.scores[metric].append(metrics_scores[metric])

        metrics = [
            # "coverage_penalty",
            # "coverage_deviation_penalty",
            # "confidence",
            # "length",
            # "ap_in",
            # "ap_out",
            # "random",
            "keyphrase_score"
        ]
        n_iters = 1
        for i, metric in enumerate(metrics):
            self.metric_bleu_scores[metric] = []
            self.metric_gleu_scores[metric] = []
            self.metric_precisions[metric] = []
            self.metric_recalls[metric] = []
            self.cer[metric] = []
            for _ in range(n_iters):
                self.evaluate_metric(sources, targets, translations,
                                     self.scores[metric] if metric != "random" else [],
                                     metric,
                                     target_keyphrases,
                                     need_sort=(metric != "random"),
                                     reverse=sort_direction[metric] if metric != "random" else True)

                # plt.plot(x, delta_bleus, marker='', linestyle="--", color=self.palette[i], linewidth=1, alpha=0.9,
                #        label=metric)
            self.save_data()

    def shuffle_list(self, *ls):
        # Shuffle several parallel lists in unison.
        zipped = list(zip(*ls))
        random.shuffle(zipped)
        return zip(*zipped)

    def evaluate_metric(self, sources, targets, translations, scores, metric, target_keyphrases,
                        need_sort=True,
                        reverse=False):
        print("Evaluating {}".format(metric))
        base_bleu = compute_bleu(targets, translations)
        print("Base BLEU: {}".format(base_bleu))
        # Sort by metric
        if need_sort:
            sorted_sentences = [(x, y, z) for _, x, y, z in
                                sorted(zip(scores, sources, targets, translations), reverse=reverse)]
            sources, targets, translations = zip(*sorted_sentences)
        else:
            sources, targets, translations = self.shuffle_list(sources, targets, translations)

        n = len(sources)
        encoder_optimizer_state, decoder_optimizer_state = None, None

        corrected_translations = []

        cer_improvement = []
        curr_cer = 0

        for i in range(1, n + 1):
            print()
            print("{}: Correcting {} of {} sentences".format(metric, i, n))

            curr_end = i

            # Compute BLEU before training for comparison
            pretraining_bleu = compute_bleu(targets[:curr_end], translations[:curr_end])
            pretraining_gleu = compute_gleu(targets[:curr_end], translations[:curr_end])
            prerecall = unigram_recall(target_keyphrases, targets[:curr_end], translations[:curr_end])
            preprecision = unigram_precision(target_keyphrases, targets[:curr_end], translations[:curr_end])

            precer = cer(targets[i - 1].replace("@@ ", "").split(), translations[i - 1].replace("@@ ", "").split())

            # Retranslate with the incrementally retrained model
            translation, _, _ = self.model.translate(sources[i - 1])
            corrected_translations.append(" ".join(translation[:-1]))

            postcer = cer(targets[i - 1].replace("@@ ", "").split(),
                          " ".join(translation[:-1]).replace("@@ ", "").split())
            curr_cer = precer - postcer
            cer_improvement.append(curr_cer)

            # Compute posttraining BLEU
            posttraining_bleu = compute_bleu(targets[:curr_end], corrected_translations)
            posttraining_gleu = compute_gleu(targets[:curr_end], corrected_translations)

            postrecall = unigram_recall(target_keyphrases, targets[:curr_end], corrected_translations)
            postprecision = unigram_precision(target_keyphrases, targets[:curr_end], corrected_translations)
            print("Delta Recall {} -> {}".format(prerecall, postrecall))
            print("Delta Precision {} -> {}".format(preprecision, postprecision))
            print("Delta BLEU: {} -> {}".format(pretraining_bleu, posttraining_bleu))
            print("Delta CER: {} -> {}".format(precer, postcer))

            self.metric_bleu_scores[metric].append((pretraining_bleu, posttraining_bleu))
            self.metric_gleu_scores[metric].append((pretraining_gleu, posttraining_gleu))
            self.metric_recalls[metric].append((prerecall, postrecall))
            self.metric_precisions[metric].append((preprecision, postprecision))

            # Now train, and compute BLEU again
            encoder_optimizer_state, decoder_optimizer_state = retrain_iters(
                self.model,
                [[sources[i - 1], targets[i - 1]]], [],
                batch_size=1,
                encoder_optimizer_state=encoder_optimizer_state,
                decoder_optimizer_state=decoder_optimizer_state,
                n_epochs=1, learning_rate=0.00005, weight_decay=1e-3)

        self.cer[metric] = cer_improvement
        reload_model(self.model)

    def plot(self):
        plt.xlabel('% Corrected Sentences')
        plt.ylabel('Δ BLEU')
        # Add titles
        plt.title("BLEU Change for Metrics", loc='center', fontsize=12, fontweight=0)
        # Add legend
        plt.legend(loc='lower right', ncol=1)
        plt.savefig('bleu_deltas.png')
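A minimal sketch of driving AveragedMetricExperiment end to end; load_model() is the loader from Example #4 and the file paths are assumptions:

model = load_model()
experiment = AveragedMetricExperiment(
    model,
    source_file="data/khresmoi.bpe.de",      # hypothetical paths
    target_file="data/khresmoi.bpe.en",
    raw_source_file="data/khresmoi.tok.de",
    raw_target_file="data/khresmoi.tok.en",
    num_sentences=400)
experiment.run()   # translate, score, retrain sentence by sentence; pickles results
experiment.plot()  # writes bleu_deltas.png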
Example #6
import sys

sys.path.insert(0, './myseq2seq')

from models import AttnDecoderRNN, EncoderRNN, LSTMAttnDecoderRNN, LSTMEncoderRNN, Seq2SeqModel

import hp
from hp import PAD_token, SOS_token, EOS_token, MIN_LENGTH, MAX_LENGTH, hidden_size, batch_size, n_epochs, embed_size
from data_loader import LanguagePairLoader, DateConverterLoader
from train import train_iters
import pickle
import os.path

import torch

use_cuda = torch.cuda.is_available()

loader = LanguagePairLoader("de", "en")
eval_loader = LanguagePairLoader("de", "en", hp.source_test_file,
                                 hp.target_test_file)

input_lang, output_lang, pairs = None, None, None

_, _, eval_pairs = eval_loader.load()

# Build the vocabularies from the corpus and cache them, unless cached
# dictionaries can be reused.
if hp.load_vocabs or not os.path.isfile(hp.prefix + "input.dict"):
    input_lang, output_lang, pairs = loader.load()
    with open(hp.prefix + "input.dict", "wb") as f:
        pickle.dump(input_lang, f)
    with open(hp.prefix + "output.dict", "wb") as f:
        pickle.dump(output_lang, f)
else:
    with open(hp.prefix + "input.dict", "rb") as f:
        input_lang = pickle.load(f)
    with open(hp.prefix + "output.dict", "rb") as f:
        output_lang = pickle.load(f)