Пример #1
0
    def __init__(self, model, source_file, target_file, test_source_file, test_target_file,
                 raw_source_file,
                 raw_target_file, num_sentences=400,
                 batch_translate=True):
        """Configure a German->English metric experiment.

        Args:
            model: translation model under evaluation.
            source_file, target_file: parallel retraining data, read with
                ``LanguagePairLoader("de", "en", ...)``.
            test_source_file, test_target_file: parallel test data.
            raw_source_file, raw_target_file: domain text from which source-
                and target-side keyphrases are extracted.
            num_sentences: number of retraining sentences to use.
            batch_translate: stored flag (consumed by methods not shown here).
        """
        self.model = model
        self.source_file = source_file
        self.target_file = target_file
        self.loader = LanguagePairLoader("de", "en", source_file, target_file)
        self.test_loader = LanguagePairLoader("de", "en", test_source_file, test_target_file)
        self.extractor = DomainSpecificExtractor(source_file=raw_source_file, train_source_file=hp.source_file,
                                                 train_vocab_file="train_vocab.pkl")
        # The target-side (English) extractor must be compared against the
        # target training corpus, matching its English vocab cache
        # ("train_vocab_en.pkl"); passing hp.source_file here mixed languages.
        self.target_extractor = DomainSpecificExtractor(source_file=raw_target_file, train_source_file=hp.target_file,
                                                        train_vocab_file="train_vocab_en.pkl")
        self.scorer = Scorer()
        self.scores = {}
        self.num_sentences = num_sentences
        self.batch_translate = batch_translate
        self.evaluate_every = 10

        # Per-metric result tables, filled in later.
        self.metric_bleu_scores = {}
        self.metric_gleu_scores = {}
        self.metric_precisions = {}
        self.metric_recalls = {}

        # Plot each metric.
        # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in matplotlib
        # 3.6 and the old name was removed in 3.8 (style.use raises OSError).
        try:
            plt.style.use('seaborn-darkgrid')
        except OSError:
            plt.style.use('seaborn-v0_8-darkgrid')
        self.palette = sns.color_palette()
Пример #2
0
    def run(self):
        """Translate the first ``num_sentences`` de->en pairs and bucket their
        per-sentence CTER by the value each metric assigned to the sentence.

        Each source is translated greedily (``beam_size=1``) and scored by
        ``self.scorer`` against keyphrases extracted from
        ``data/khresmoi.tok.de``. For every kept metric, the CTER of the
        translation (final token dropped) against the reference is appended
        to ``self.all_gleu_scores`` and to
        ``self.metric_to_gleu[metric][value]``. Despite the attribute names,
        the stored values come from ``compute_cter``.

        Skipped as outliers: coverage_penalty > 80 and keyphrase_score == 0.
        """
        loader = LanguagePairLoader("de", "en", self.source_file, self.target_file)
        _, _, pairs = loader.load()
        pairs = pairs[:self.num_sentences]

        extractor = DomainSpecificExtractor(source_file="data/khresmoi.tok.de",
                                            train_source_file=hp.source_file,
                                            train_vocab_file="train_vocab.pkl")
        keyphrases = extractor.extract_keyphrases(n_results=100)
        print(keyphrases)

        for i, pair in enumerate(pairs):
            if i % 10 == 0:
                print("Translated {} of {}".format(i, len(pairs)))
            translation, attn, _ = self.model.translate(pair[0], beam_size=1)
            hypothesis = " ".join(translation[:-1])
            # NOTE(review): the scorer receives the full token list (including
            # the final token) while CTER uses ``hypothesis`` without it; kept
            # as in the original -- confirm which is intended.
            scores = self.scorer.compute_scores(pair[0], " ".join(translation), attn, keyphrases)

            # CTER depends only on the pair, not on the metric: compute it at
            # most once per sentence (and only if some metric is kept).
            cter = None
            for metric, value in scores.items():
                if metric == "coverage_penalty" and value > 80:
                    continue
                if metric == "keyphrase_score" and value == 0:
                    continue
                if cter is None:
                    cter = compute_cter(pair[1], hypothesis)
                self.all_gleu_scores.append(cter)
                self.metric_to_gleu.setdefault(metric, {}).setdefault(value, []).append(cter)
Пример #3
0
def retranslate(document_id):
    """Re-translate every sentence of a stored document and save it.

    Keyphrases are extracted from the document's source file (compared with
    the WMT14 BPE training data for ``SRC_LANG``). Each sentence is handed, in
    order, to ``retranslateSentenceWithId`` together with a shared ``Scorer``
    and the running change counter, which that helper returns updated.
    Its returned sentence is discarded here (presumably the helper mutates
    the sentence in place -- confirm).

    Returns:
        JSON response ``{"numChanges": <total changes>}``.
    """
    doc = get_document(document_id)
    sentence_scorer = Scorer()
    domain_extractor = DomainSpecificExtractor(source_file=doc.filepath, src_lang=SRC_LANG, tgt_lang=TGT_LANG,
                                               train_source_file=f".data/wmt14/train.tok.clean.bpe.32000.{SRC_LANG}",
                                               train_vocab_file=f".data/vocab/train_vocab_{SRC_LANG}.pkl")
    phrases = domain_extractor.extract_keyphrases()

    changes = 0
    for index, current in enumerate(doc.sentences):
        _, changes = retranslateSentenceWithId(index, current, sentence_scorer, phrases, changes)

    save_document(doc, document_id)
    return jsonify({"numChanges": changes})
Пример #4
0
def retranslateSentence(document_id, sentence_id, beam_size, att_layer):
    """Force re-translation of a single sentence and save the document.

    ``sentence_id``, ``beam_size`` and ``att_layer`` may arrive as strings;
    they are converted with ``int`` before use (``sentence_id`` itself is
    forwarded unconverted as the first argument, as before).

    Returns:
        An empty JSON response.
    """
    doc = get_document(document_id)
    sentence_scorer = Scorer()
    domain_extractor = DomainSpecificExtractor(source_file=doc.filepath, src_lang=SRC_LANG, tgt_lang=TGT_LANG,
                                               train_source_file=f".data/wmt14/train.tok.clean.bpe.32000.{SRC_LANG}",
                                               train_vocab_file=f".data/vocab/train_vocab_{SRC_LANG}.pkl")
    phrases = domain_extractor.extract_keyphrases()

    target_sentence = doc.sentences[int(sentence_id)]
    # Change counter starts at 0; the returned count is not needed here.
    retranslateSentenceWithId(sentence_id, target_sentence, sentence_scorer, phrases,
                              0, int(beam_size), int(att_layer), force=True)

    save_document(doc, document_id)
    return jsonify({})
Пример #5
0
def correctTranslation():
    """Apply a user-corrected translation to one sentence of a document.

    Reads the JSON request body (keys ``translation``, ``beam``,
    ``document_unk_map``, ``attention``, ``document_id``, ``sentence_id``),
    merges the client's unknown-word map into the document's, stores the new
    translation/attention/beam, records an HTML diff when the text changed,
    marks the sentence corrected and unflagged, re-scores it (keeping its
    previous ``order_id``) and saves the document.

    Returns:
        An empty JSON response.
    """
    payload = request.get_json()
    (translation, beam, document_unk_map,
     attention, document_id, sentence_id) = (payload[field] for field in (
        "translation", "beam", "document_unk_map",
        "attention", "document_id", "sentence_id"))

    document = get_document(document_id)

    domain_extractor = DomainSpecificExtractor(source_file=document.filepath, src_lang=SRC_LANG, tgt_lang=TGT_LANG,
                                               train_source_file=f".data/wmt14/train.tok.clean.bpe.32000.{SRC_LANG}",
                                               train_vocab_file=f".data/vocab/train_vocab_{SRC_LANG}.pkl")
    phrases = domain_extractor.extract_keyphrases()

    # Keys present on both sides get the union of their value lists
    # (set semantics: duplicates dropped, order not preserved).
    for key in document_unk_map:
        incoming = document_unk_map[key]
        if key in document.unk_map:
            document.unk_map[key] = list(set(document.unk_map[key]) | set(incoming))
        else:
            document.unk_map[key] = incoming

    index = int(sentence_id)
    sentence = document.sentences[index]

    if translation != sentence.translation:
        # Trim the last 4 characters (presumably an end marker -- confirm)
        # and BPE joiners before diffing.
        old_text = sentence.translation[:-4].replace("@@ ", "")
        new_text = translation[:-4].replace("@@ ", "")
        sentence.diff = html_diff(old_text, new_text)
    sentence.translation = translation
    sentence.corrected = True
    sentence.flagged = False
    sentence.attention = attention
    sentence.beam = beam

    sentence_scorer = Scorer()
    new_score = sentence_scorer.compute_scores(sentence.source, sentence.translation, attention, phrases, "")
    new_score["order_id"] = sentence.score["order_id"]
    sentence.score = new_score

    document.sentences[index] = sentence
    save_document(document, document_id)
    return jsonify({})
Пример #6
0
    def run(self, src_lang, tgt_lang, dir, translationFile, scoresFile,
            attFile):
        """Translate both evaluation sets (or reload cached results) and
        bucket each sentence's CTER by the value it got for every metric.

        Args:
            src_lang, tgt_lang: language codes for the loaders and extractor.
            dir: directory that freshly computed translations, scores and
                attentions are pickled into.
            translationFile, scoresFile, attFile: pickle file names. Cached
                results are looked up under ``_experiments/translated_beam3``,
                not under ``dir``.

        Side effects: sets ``self.pairs``, ``self.translationList`` and
        ``self.scoresList``; appends to ``self.all_cter_scores`` and
        ``self.metric_to_cter[metric][value]``.
        """
        loader = LanguagePairLoader(src_lang, tgt_lang, self.source_file,
                                    self.target_file)
        _, _, pairs = loader.load()

        loader2 = LanguagePairLoader(src_lang, tgt_lang, self.source_file2,
                                     self.target_file2)
        _, _, pairs2 = loader2.load()

        # Concatenate both sets (=> all 1500 sentences) and truncate.
        self.pairs = (pairs + pairs2)[:self.num_sentences]

        extractor = DomainSpecificExtractor(
            source_file=self.source_file,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            train_source_file=
            f".data/wmt14/train.tok.clean.bpe.32000.{src_lang}",
            train_vocab_file=f".data/vocab/train_vocab_{src_lang}.pkl")

        keyphrases = extractor.extract_keyphrases(n_results=100)

        self.translationList = []
        attentionList = []
        self.scoresList = []
        prefix = "_experiments/translated_beam3"
        translation_path = os.path.join(prefix, translationFile)
        scores_path = os.path.join(prefix, scoresFile)
        att_path = os.path.join(prefix, attFile)

        if (os.path.isfile(translation_path) and os.path.isfile(scores_path)
                and os.path.isfile(att_path)):
            print("Translation reloaded")
            with open(translation_path, 'rb') as f:
                self.translationList = pickle.load(f)
            with open(att_path, 'rb') as f:
                attentionList = pickle.load(f)
            with open(scores_path, 'rb') as f:
                self.scoresList = pickle.load(f)
        else:
            for i, pair in enumerate(self.pairs):
                if i % 10 == 0:
                    print("Translated {} of {}".format(i, len(self.pairs)))

                translation, attn, _ = self.model.translate(
                    pair[0], beam_size=self.beam_size)
                scores = self.scorer.compute_scores(pair[0],
                                                    " ".join(translation),
                                                    attn, keyphrases, "")

                self.translationList.append(translation)
                attentionList.append(attn)
                self.scoresList.append(scores)

            # NOTE(review): results are written to ``dir`` but reloaded from
            # ``prefix`` above; they only round-trip when the two coincide.
            # ``with`` closes each file deterministically instead of leaving
            # the handle to the garbage collector.
            for obj, name in ((self.translationList, translationFile),
                              (self.scoresList, scoresFile),
                              (attentionList, attFile)):
                with open(os.path.join(dir, name), "wb") as f:
                    pickle.dump(obj, f)

        # Outlier filtering (coverage_penalty > 45, keyphrase_score == 0) was
        # disabled in the original; every metric value is kept.
        for i, pair in enumerate(self.pairs):
            if i % 10 == 0:
                print("Processing {} of {}".format(i, len(self.pairs)))

            # CTER depends only on the sentence, not the metric: compute once.
            cter = None
            for metric, value in self.scoresList[i].items():
                if cter is None:
                    cter = compute_cter(pair[1],
                                        " ".join(self.translationList[i][:-1]))
                self.all_cter_scores.append(cter)
                self.metric_to_cter.setdefault(metric, {}).setdefault(
                    value, []).append(cter)
Пример #7
0
    seq2seq_model = Seq2SeqModel(encoder, decoder, input_lang, output_lang)

    return seq2seq_model


def reload_model(seq2seq_model, checkpoint_path=None, map_location=None):
    """Restore the encoder and decoder weights of ``seq2seq_model`` in place.

    Args:
        seq2seq_model: object exposing ``encoder`` and ``decoder`` modules
            with ``load_state_dict``.
        checkpoint_path: checkpoint file holding ``"encoder"`` and
            ``"decoder"`` state dicts. Defaults to ``hp.checkpoint_name``
            (looked up at call time, as before).
        map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` to load a
            GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
            torch's default.

    Raises:
        KeyError: if the checkpoint lacks an ``"encoder"`` or ``"decoder"``
            entry (checked before any weights are loaded).
    """
    if checkpoint_path is None:
        checkpoint_path = hp.checkpoint_name
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    encoder_state = checkpoint["encoder"]
    decoder_state = checkpoint["decoder"]

    seq2seq_model.encoder.load_state_dict(encoder_state)
    seq2seq_model.decoder.load_state_dict(decoder_state)


def keyphrase_score(sentence, keyphrases):
    """Score ``sentence`` by case-insensitive keyphrase occurrences.

    For every space-separated word and every ``(keyphrase, freq)`` pair, the
    number of (substring) occurrences of the keyphrase inside the word is
    multiplied by ``freq`` and summed.

    Args:
        sentence: text split on single spaces.
        keyphrases: iterable of ``(keyphrase, freq)`` pairs. It is consumed
            once, so generators/iterators are supported.

    Returns:
        The accumulated score (``0`` when nothing matches).
    """
    # Lower-case each keyphrase once instead of once per word; materializing
    # also lets a one-shot iterator be reused for every word.
    lowered = [(keyphrase.lower(), freq) for keyphrase, freq in keyphrases]

    score = 0
    for word in sentence.split(" "):
        word = word.lower()
        for keyphrase, freq in lowered:
            score += word.count(keyphrase) * freq
    return score


# Ad-hoc check at import time: extract keyphrases of the German medical
# domain file relative to the training corpus and print them.
extractor = DomainSpecificExtractor(
    source_file="data/medical.tok.de",
    train_source_file=hp.source_file,
    train_vocab_file="train_vocab.pkl",
)
words = extractor.extract_keyphrases()
print(words)
Пример #8
0
    def __init__(
        self,
        model,
        src_lang,
        tgt_lang,
        model_type,
        source_file,
        target_file,
        test_source_file,
        test_target_file,
        dir,
        evaluate_every=10,
        num_sentences=400,
        num_sentences_test=500,
        reuseCalculatedTranslations=False,
        reuseInitialTranslations=False,
        initialTranslationFile="",
        initialScoreFile="",
        initialTestTranslationFile="",
        translationFile="",
        batch_translate=True,
    ):
        """Store experiment settings and build loaders, extractors and scorer.

        ``source_file``/``target_file`` are the retraining data and
        ``test_source_file``/``test_target_file`` the test data. Results go to
        ``dir``; the ``initial*``/``translationFile`` names and ``reuse*``
        flags control caching of translations.
        """

        def build_extractor(domain_file, lang, other_lang):
            # Keyphrase extractor for ``domain_file`` measured against the
            # WMT14 BPE training corpus / vocab cache of ``lang``.
            return DomainSpecificExtractor(
                source_file=domain_file,
                src_lang=lang,
                tgt_lang=other_lang,
                train_source_file=f".data/wmt14/train.tok.clean.bpe.32000.{lang}",
                train_vocab_file=f".data/vocab/train_vocab_{lang}.pkl",
            )

        # Plain configuration.
        self.model = model
        self.src_lang, self.tgt_lang = src_lang, tgt_lang
        self.model_type = model_type
        self.source_file, self.target_file = source_file, target_file

        # Collaborators, constructed in the same order as before.
        self.loader = LanguagePairLoader(src_lang, tgt_lang, source_file, target_file)
        self.test_loader = LanguagePairLoader(src_lang, tgt_lang, test_source_file, test_target_file)
        self.extractor = build_extractor(source_file, src_lang, tgt_lang)
        self.target_extractor = build_extractor(target_file, tgt_lang, src_lang)
        self.scorer = Scorer()

        # Run settings.
        self.scores = {}
        self.num_sentences, self.num_sentences_test = num_sentences, num_sentences_test
        self.batch_translate = batch_translate
        self.evaluate_every = evaluate_every
        self.reuseCalculatedTranslations = reuseCalculatedTranslations
        self.reuseInitialTranslations = reuseInitialTranslations

        # Cache file names.
        self.initialTranslationFile = initialTranslationFile
        self.initialScoreFile = initialScoreFile
        self.initialTestTranslationFile = initialTestTranslationFile
        self.translationFile = translationFile

        # Per-metric result tables.
        self.metric_bleu_scores, self.metric_gleu_scores = {}, {}
        self.metric_precisions, self.metric_recalls = {}, {}

        self.prefix = "_experiments/retrain_beam3"
        self.dir = dir
Пример #9
0
class MetricExperiment:
    """Measure how retraining on metric-ranked sentences changes test quality.

    For each ranking metric ("random", "keyphrase_score", "coverage_penalty",
    "confidence", "length") the retraining pairs are ordered by the metric's
    per-sentence score (or shuffled for "random"). The model is retrained in
    chunks of ``evaluate_every`` sentences, and after each chunk the test set
    is re-translated. BLEU, GLEU and keyphrase unigram recall/precision are
    recorded as ``(before, after)`` tuples.

    Results are pickled into ``dir``; cached translations are restored from
    the fixed ``self.prefix`` directory when the ``reuse*`` flags are set.
    """

    def __init__(
        self,
        model,
        src_lang,
        tgt_lang,
        model_type,
        source_file,
        target_file,
        test_source_file,
        test_target_file,
        dir,
        evaluate_every=10,
        num_sentences=400,
        num_sentences_test=500,
        reuseCalculatedTranslations=False,
        reuseInitialTranslations=False,
        initialTranslationFile="",
        initialScoreFile="",
        initialTestTranslationFile="",
        translationFile="",
        batch_translate=True,
    ):

        self.model = model
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.model_type = model_type
        self.source_file = source_file
        self.target_file = target_file
        self.loader = LanguagePairLoader(src_lang, tgt_lang, source_file,
                                         target_file)
        self.test_loader = LanguagePairLoader(src_lang, tgt_lang,
                                              test_source_file,
                                              test_target_file)

        self.extractor = DomainSpecificExtractor(
            source_file=source_file,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            train_source_file=
            f".data/wmt14/train.tok.clean.bpe.32000.{src_lang}",
            train_vocab_file=f".data/vocab/train_vocab_{src_lang}.pkl")

        self.target_extractor = DomainSpecificExtractor(
            source_file=target_file,
            src_lang=tgt_lang,
            tgt_lang=src_lang,
            train_source_file=
            f".data/wmt14/train.tok.clean.bpe.32000.{tgt_lang}",
            train_vocab_file=f".data/vocab/train_vocab_{tgt_lang}.pkl")

        self.scorer = Scorer()
        self.scores = {}
        self.num_sentences = num_sentences
        self.num_sentences_test = num_sentences_test
        self.batch_translate = batch_translate
        self.evaluate_every = evaluate_every
        self.reuseCalculatedTranslations = reuseCalculatedTranslations
        self.reuseInitialTranslations = reuseInitialTranslations

        self.initialTranslationFile = initialTranslationFile
        self.initialScoreFile = initialScoreFile
        self.initialTestTranslationFile = initialTestTranslationFile
        self.translationFile = translationFile

        self.metric_bleu_scores = {}
        self.metric_gleu_scores = {}
        self.metric_precisions = {}
        self.metric_recalls = {}

        self.prefix = "_experiments/retrain_beam3"
        self.dir = dir

    @staticmethod
    def _dump_pickle(obj, path):
        """Pickle ``obj`` to ``path``, closing the file deterministically."""
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    def save_data(self):
        """Pickle the four per-metric score tables into ``self.dir``.

        Files are named ``{batch_|beam_}{evaluate_every}_metric_*.pkl``.
        """
        mode = "batch_" if self.batch_translate else "beam_"
        prefix = os.path.join(self.dir, mode + str(self.evaluate_every) + "_")
        for suffix, table in (
                ("metric_bleu_scores.pkl", self.metric_bleu_scores),
                ("metric_gleu_scores.pkl", self.metric_gleu_scores),
                ("metric_precisions.pkl", self.metric_precisions),
                ("metric_recalls.pkl", self.metric_recalls),
        ):
            self._dump_pickle(table, prefix + suffix)
        print("Saved all scores")

    def save_translation(self, translation, metric, step):
        """Pickle test-set translations produced after retraining step ``step``."""
        name = os.path.join(self.dir,
                            metric + "_" + str(step) + self.translationFile)
        self._dump_pickle(translation, name)
        print("Saved: " + name)

    def restore_translation(self, metric, step):
        """Load translations saved for ``metric``/``step`` from ``self.prefix``."""
        name = os.path.join(self.prefix,
                            metric + "_" + str(step) + self.translationFile)
        with open(name, 'rb') as f:
            return pickle.load(f)

    def save_initialTranslation(self, scores, translations):
        """Pickle the initial retraining-set translations and metric scores."""
        name = os.path.join(self.dir, self.initialTranslationFile)
        self._dump_pickle(translations, name)
        name = os.path.join(self.dir, self.initialScoreFile)
        self._dump_pickle(scores, name)
        print("Saved: " + name)

    def restore_initialTranslation(self):
        """Return ``(translations, scores)`` cached under ``self.prefix``."""
        name = os.path.join(self.prefix, self.initialTranslationFile)
        with open(name, 'rb') as f:
            translations = pickle.load(f)
        name = os.path.join(self.prefix, self.initialScoreFile)
        with open(name, 'rb') as f:
            scores = pickle.load(f)
        return translations, scores

    def save_initialTestTranslation(self, translations):
        """Pickle the test-set translations of the untouched model."""
        name = os.path.join(self.dir, self.initialTestTranslationFile)
        self._dump_pickle(translations, name)
        print("Saved: " + name)

    def restore_initialTestTranslation(self):
        """Load the cached initial test-set translations from ``self.prefix``."""
        name = os.path.join(self.prefix, self.initialTestTranslationFile)
        with open(name, 'rb') as f:
            return pickle.load(f)

    def _translate_test_set(self, test_sources):
        """Translate ``test_sources`` with the current model.

        With ``self.batch_translate`` a single ``model.batch_translate`` call
        is made and the last 6 characters of each output are dropped
        (presumably a trailing end-of-sentence marker -- confirm). Otherwise
        each sentence is translated individually and its final token dropped.
        """
        if self.batch_translate:
            return [t[:-6] for t in self.model.batch_translate(test_sources)]
        translations = []
        for source in tqdm(test_sources):
            translation, _, _ = self.model.translate(source)
            translations.append(" ".join(translation[:-1]))
        return translations

    def run(self):
        """Run the retraining experiment for every ranking metric."""
        _, _, pairs = self.loader.load()
        random.shuffle(pairs)

        pairs = pairs[:self.num_sentences]
        sources = [p[0] for p in pairs]
        targets = [p[1] for p in pairs]
        translations = []

        keyphrases = self.extractor.extract_keyphrases(n_results=100)

        target_keyphrases = self.target_extractor.extract_keyphrases(
            n_results=100)

        # Previously these flags were read as bare (undefined) names.
        reuse = self.reuseCalculatedTranslations or self.reuseInitialTranslations

        # translation and scores for order of retraining
        print('Translating ...')
        if not reuse:
            for pair in tqdm(pairs):
                translation, attn, _ = self.model.translate(pair[0])
                hypothesis = " ".join(translation[:-1])
                translations.append(hypothesis)

                metrics_scores = self.scorer.compute_scores(
                    pair[0], hypothesis, attn, keyphrases, "")
                for metric, value in metrics_scores.items():
                    self.scores.setdefault(metric, []).append(value)
            self.save_initialTranslation(self.scores, translations)
        else:
            translations, self.scores = self.restore_initialTranslation()

        # initial test set translation
        _, _, test_pairs = self.test_loader.load()
        test_pairs = test_pairs[:self.num_sentences_test]
        test_sources = [p[0] for p in test_pairs]
        test_targets = [p[1] for p in test_pairs]

        if not reuse:
            print('- not reusing translations: Translating...')
            # Translate once: per-sentence results used to be computed and
            # then discarded whenever batch translation was enabled.
            test_translations = self._translate_test_set(test_sources)
            self.save_initialTestTranslation(test_translations)
        else:
            test_translations = self.restore_initialTestTranslation()

        metrics = [
            "random", "keyphrase_score", "coverage_penalty", "confidence",
            "length"
        ]

        print("Evaluating metrics...")
        for metric in tqdm(metrics):
            for table in (self.metric_bleu_scores, self.metric_gleu_scores,
                          self.metric_precisions, self.metric_recalls):
                table[metric] = []

            is_random = metric == "random"
            # Copies, so each metric starts from the same unsorted data.
            self.evaluate_metric(
                self.src_lang,
                self.tgt_lang,
                self.model_type,
                sources[:],
                targets[:],
                translations[:],
                [] if is_random else self.scores[metric],
                metric,
                target_keyphrases,
                test_sources,
                test_targets,
                test_translations,
                need_sort=not is_random,
                reverse=True if is_random else reverse_sort_direction[metric])
            print()
            print(self.metric_bleu_scores)
            self.save_data()

    def shuffle_list(self, *ls):
        """Shuffle parallel sequences together; returns a ``zip`` of tuples."""
        l = list(zip(*ls))

        random.shuffle(l)
        return zip(*l)

    def evaluate_metric(self,
                        src_lang,
                        tgt_lang,
                        model_type,
                        sources,
                        targets,
                        translations,
                        scores,
                        metric,
                        target_keyphrases,
                        test_sources,
                        test_targets,
                        test_translations,
                        need_sort=True,
                        reverse=False):
        """Retrain chunk by chunk in ``metric`` order and record test scores.

        The model is reloaded with ``load_model`` at the end so the next
        metric starts from the initial weights.
        """
        print()
        print("Evaluating {}".format(metric))
        base_bleu = compute_bleu(targets, translations)
        print("Base BLEU (of retraining data): {}".format(base_bleu))

        # Order the retraining data by metric score, or shuffle for "random".
        # Ties on the score fall back to comparing source/target/translation.
        if need_sort:
            ranked = sorted(zip(scores, sources, targets, translations),
                            reverse=reverse)
            sources = [r[1] for r in ranked]
            targets = [r[2] for r in ranked]
        elif sources:  # unpacking an empty shuffle result would raise
            sources, targets, translations = self.shuffle_list(
                sources, targets, translations)

        n = len(sources)
        encoder_optimizer_state, decoder_optimizer_state = None, None

        pretraining_bleu = compute_bleu(test_targets, test_translations)
        pretraining_gleu = compute_gleu(test_targets, test_translations)
        print()
        print("pretraining BLEU of test set (before retraining)")
        print(pretraining_bleu)

        prerecall = unigram_recall(target_keyphrases, test_targets,
                                   test_translations)
        preprecision = unigram_precision(target_keyphrases, test_targets,
                                         test_translations)

        self.metric_bleu_scores[metric].append(
            (pretraining_bleu, pretraining_bleu))
        self.metric_gleu_scores[metric].append(
            (pretraining_gleu, pretraining_gleu))
        self.metric_recalls[metric].append((prerecall, prerecall))
        self.metric_precisions[metric].append((preprecision, preprecision))
        self.save_data()

        # Initialized for every model type so the non-Seq2Seq branch below
        # never sees an unbound name.
        current_ckpt = None
        if isinstance(self.model, TransformerTranslator):
            # create a new checkpoint here that gets overwritten with each ij
            # Neccessary to load trainer state.
            current_ckpt = f'.data/models/transformer/trafo_{src_lang}_{tgt_lang}_ensemble.pt'

        print('Training...')
        # One iteration per chunk: i = 0, evaluate_every, 2 * evaluate_every, ...
        # covering sentences (0..9), (10..19), ... for evaluate_every = 10.
        for i in tqdm(range(0, n, self.evaluate_every)):
            if not self.reuseCalculatedTranslations:

                # Now train, and compute BLEU again
                end = min(i + self.evaluate_every, n)

                print()
                print("Correcting {} - {} of {} sentences".format(
                    i, end - 1, n))

                chunk = [[x, y] for x, y in zip(sources[i:end], targets[i:end])]
                # NOTE(review): ``i`` only takes multiples of evaluate_every,
                # so ``i == n - 1`` holds only when (n - 1) % evaluate_every
                # == 0 (never for n=400, evaluate_every=10). Left unchanged
                # because saving could overwrite the checkpoint that the final
                # load_model() call may rely on -- confirm intent.
                save_ckpt = i == n - 1

                if isinstance(self.model, Seq2SeqModel):
                    # same parameters that are used in the tool
                    encoder_optimizer_state, decoder_optimizer_state = retrain_iters(
                        self.model,
                        chunk, [],
                        src_lang,
                        tgt_lang,
                        batch_size=1,
                        encoder_optimizer_state=encoder_optimizer_state,
                        decoder_optimizer_state=decoder_optimizer_state,
                        print_every=1,
                        n_epochs=15,
                        learning_rate=0.0001,
                        save_ckpt=save_ckpt)
                else:
                    # same parameters that are used in the tool
                    current_ckpt = self.model.retrain(
                        src_lang,
                        tgt_lang,
                        chunk,
                        last_ckpt=current_ckpt,
                        epochs=15,
                        batch_size=1,
                        device=DEVICE,
                        save_ckpt=save_ckpt,
                        print_info=False)

                print(' - Translate using trained model')
                corrected_translations = self._translate_test_set(test_sources)
                self.save_translation(corrected_translations, metric, i)

            else:
                corrected_translations = self.restore_translation(metric, i)

            # Compute posttraining BLEU
            posttraining_bleu = compute_bleu(test_targets,
                                             corrected_translations)
            posttraining_gleu = compute_gleu(test_targets,
                                             corrected_translations)
            postrecall = unigram_recall(target_keyphrases, test_targets,
                                        corrected_translations)
            postprecision = unigram_precision(target_keyphrases, test_targets,
                                              corrected_translations)
            print("(Base BLEU {})".format(base_bleu))
            print("Delta Recall {} -> {}".format(prerecall, postrecall))
            print("Delta Precision {} -> {}".format(preprecision,
                                                    postprecision))
            print("Delta GLEU: {} -> {}".format(pretraining_gleu,
                                                posttraining_gleu))
            print("Delta BLEU: {} -> {}".format(pretraining_bleu,
                                                posttraining_bleu))

            delta_bleu = posttraining_bleu - pretraining_bleu
            print("Delta: {}".format(delta_bleu))

            self.metric_bleu_scores[metric].append(
                (pretraining_bleu, posttraining_bleu))
            self.metric_gleu_scores[metric].append(
                (pretraining_gleu, posttraining_gleu))
            self.metric_recalls[metric].append((prerecall, postrecall))
            self.metric_precisions[metric].append(
                (preprecision, postprecision))

            self.save_data()

        self.model = load_model(src_lang, tgt_lang, model_type,
                                device=DEVICE)  # reload initial model
        return None
Пример #10
0
class AveragedMetricExperiment:
    """Simulate incremental post-editing of a sentence sample, ordered by a ranking metric.

    ``run`` translates a random sample of sentence pairs with ``model`` and scores
    each translation with ``Scorer``. ``evaluate_metric`` then walks the sentences
    in metric order. Each sentence is re-translated by the model as fine-tuned on
    all previously "corrected" pairs, and then the model is fine-tuned on that
    sentence's reference. For every step it records before/after BLEU, GLEU,
    keyphrase unigram recall/precision and per-sentence CER improvement.
    """

    def __init__(self, model, source_file, target_file, raw_source_file, raw_target_file, num_sentences=400):
        self.model = model
        self.source_file = source_file
        self.target_file = target_file
        # Parallel de -> en corpus whose pairs are translated and then "corrected".
        self.loader = LanguagePairLoader("de", "en", source_file, target_file)
        # Keyphrase extractors for the raw in-domain source / target text,
        # each built against the corresponding training-side file and vocabulary.
        self.extractor = DomainSpecificExtractor(source_file=raw_source_file, train_source_file=hp.source_file,
                                                 train_vocab_file="train_vocab.pkl")
        self.target_extractor = DomainSpecificExtractor(source_file=raw_target_file, train_source_file=hp.target_file,
                                                        train_vocab_file="train_vocab_en.pkl")
        self.scorer = Scorer()
        # metric name -> list of per-sentence scores, filled by run()
        self.scores = {}
        self.num_sentences = num_sentences

        # metric name -> list of (before, after) tuples, one per corrected sentence
        self.metric_bleu_scores = {}
        self.metric_gleu_scores = {}
        self.metric_precisions = {}
        self.metric_recalls = {}
        # metric name -> list of per-sentence CER improvements (before - after)
        self.cer = {}

        # Plot each metric. 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid'
        # in matplotlib 3.6 and newer releases reject the old name.
        try:
            plt.style.use('seaborn-darkgrid')
        except OSError:
            plt.style.use('seaborn-v0_8-darkgrid')
        self.palette = sns.color_palette()

    def save_data(self):
        """Pickle every collected score dictionary to ``averaged_*.pkl`` in the working directory."""
        prefix = "averaged_"
        outputs = [
            ("metric_bleu_scores.pkl", self.metric_bleu_scores),
            ("metric_gleu_scores.pkl", self.metric_gleu_scores),
            ("metric_precisions.pkl", self.metric_precisions),
            ("metric_recalls.pkl", self.metric_recalls),
            ("metric_cer.pkl", self.cer),
        ]
        for name, data in outputs:
            # Use a context manager so each file handle is closed deterministically.
            with open(prefix + name, "wb") as f:
                pickle.dump(data, f)
        print("Saved all scores")

    def run(self):
        """Translate a seeded random sample of pairs, score them, and evaluate each metric.

        ``random.seed(2018)`` makes the sample reproducible. Scores are saved via
        ``save_data`` after each metric has been evaluated.
        """
        _, _, pairs = self.loader.load()
        random.seed(2018)
        random.shuffle(pairs)

        pairs = pairs[:self.num_sentences]

        sources = [p[0] for p in pairs]
        targets = [p[1] for p in pairs]
        translations = []

        keyphrases = self.extractor.extract_keyphrases(n_results=100)
        print(keyphrases)
        target_keyphrases = self.target_extractor.extract_keyphrases(n_results=100)
        print(target_keyphrases)

        for i, pair in enumerate(pairs):
            if i % 10 == 0:
                print("Translated {} of {}".format(i, len(pairs)))
            translation, attn, _ = self.model.translate(pair[0])
            # Drop the final token (presumably the end-of-sentence marker — confirm against model.translate).
            hypothesis = " ".join(translation[:-1])
            translations.append(hypothesis)

            metrics_scores = self.scorer.compute_scores(pair[0], hypothesis, attn, keyphrases)
            for metric in metrics_scores:
                self.scores.setdefault(metric, []).append(metrics_scores[metric])

        metrics = [
            # "coverage_penalty",
            # "coverage_deviation_penalty",
            # "confidence",
            # "length",
            # "ap_in",
            # "ap_out",
            # "random",
            "keyphrase_score"
        ]
        n_iters = 1
        for metric in metrics:
            self.metric_bleu_scores[metric] = []
            self.metric_gleu_scores[metric] = []
            self.metric_precisions[metric] = []
            self.metric_recalls[metric] = []
            self.cer[metric] = []
            # The "random" baseline has no scores: sentences are shuffled instead of sorted.
            is_random = metric == "random"
            for _ in range(n_iters):
                self.evaluate_metric(sources, targets, translations,
                                     [] if is_random else self.scores[metric],
                                     metric,
                                     target_keyphrases,
                                     need_sort=not is_random,
                                     reverse=True if is_random else sort_direction[metric])
            self.save_data()

    def shuffle_list(self, *ls):
        """Shuffle several equal-length sequences in unison.

        Returns a ``zip`` iterator yielding one tuple per input sequence, with
        element alignment across sequences preserved.
        """
        rows = list(zip(*ls))
        random.shuffle(rows)
        return zip(*rows)

    def evaluate_metric(self, sources, targets, translations, scores, metric, target_keyphrases,
                        need_sort=True,
                        reverse=False):
        """Simulate correcting the sentences one by one in metric order and record score deltas.

        When ``need_sort`` is true, sentences are sorted by ``scores`` (descending
        if ``reverse``). Otherwise they are shuffled. At step i, the pre/post scores
        over the first i sentences compare the original ``translations`` against
        the re-translations produced so far. Each re-translation uses
        ``self.model`` after fine-tuning on every earlier (source, target) pair.

        Args:
            sources, targets, translations: aligned sequences of sentences.
            scores: per-sentence metric values used for sorting (ignored unless ``need_sort``).
            metric: key under which results are appended to the ``metric_*`` dicts.
            target_keyphrases: keyphrases used for unigram recall/precision.
            need_sort: sort by ``scores`` instead of shuffling.
            reverse: sort direction passed to ``sorted``.

        Returns:
            None. Results are stored on ``self``; ``reload_model(self.model)`` is
            called at the end (presumably restoring the initial weights).
        """
        print("Evaluating {}".format(metric))
        if not sources:
            # Nothing to correct; unpacking an empty zip below would raise ValueError.
            self.cer[metric] = []
            return None

        base_bleu = compute_bleu(targets, translations)
        print("Base BLEU: {}".format(base_bleu))
        # Sort by metric (ties fall back to comparing the sentence strings)
        if need_sort:
            sorted_sentences = [(x, y, z) for _, x, y, z in
                                sorted(zip(scores, sources, targets, translations), reverse=reverse)]
            sources, targets, translations = zip(*sorted_sentences)
        else:
            sources, targets, translations = self.shuffle_list(sources, targets, translations)

        n = len(sources)
        encoder_optimizer_state, decoder_optimizer_state = None, None

        corrected_translations = []
        cer_improvement = []

        for i in range(1, n + 1):
            print()
            print("{}: Correcting {} of {} sentences".format(metric, i, n))

            seen_targets = targets[:i]
            source, target = sources[i - 1], targets[i - 1]

            # Scores of the original translations for comparison
            pretraining_bleu = compute_bleu(seen_targets, translations[:i])
            pretraining_gleu = compute_gleu(seen_targets, translations[:i])
            prerecall = unigram_recall(target_keyphrases, seen_targets, translations[:i])
            preprecision = unigram_precision(target_keyphrases, seen_targets, translations[:i])

            # CER is computed on whitespace tokens with BPE joints ("@@ ") removed.
            precer = cer(target.replace("@@ ", "").split(), translations[i - 1].replace("@@ ", "").split())

            # Translate with the model being fine-tuned (not a global model object),
            # so earlier corrections are reflected in this translation.
            translation, _, _ = self.model.translate(source)
            corrected = " ".join(translation[:-1])
            corrected_translations.append(corrected)

            postcer = cer(target.replace("@@ ", "").split(), corrected.replace("@@ ", "").split())
            cer_improvement.append(precer - postcer)

            # Scores of the re-translations produced so far
            posttraining_bleu = compute_bleu(seen_targets, corrected_translations)
            posttraining_gleu = compute_gleu(seen_targets, corrected_translations)

            postrecall = unigram_recall(target_keyphrases, seen_targets, corrected_translations)
            postprecision = unigram_precision(target_keyphrases, seen_targets, corrected_translations)
            print("Delta Recall {} -> {}".format(prerecall, postrecall))
            print("Delta Precision {} -> {}".format(preprecision, postprecision))
            print("Delta BLEU: {} -> {}".format(pretraining_bleu, posttraining_bleu))
            print("Delta CER: {} -> {}".format(precer, postcer))

            self.metric_bleu_scores[metric].append((pretraining_bleu, posttraining_bleu))
            self.metric_gleu_scores[metric].append((pretraining_gleu, posttraining_gleu))
            self.metric_recalls[metric].append((prerecall, postrecall))
            self.metric_precisions[metric].append((preprecision, postprecision))

            # Fine-tune on the reference pair; optimizer state carries over between sentences.
            encoder_optimizer_state, decoder_optimizer_state = retrain_iters(self.model,
                                                                             [[source, target]], [],
                                                                             batch_size=1,
                                                                             encoder_optimizer_state=encoder_optimizer_state,
                                                                             decoder_optimizer_state=decoder_optimizer_state,
                                                                             n_epochs=1, learning_rate=0.00005,
                                                                             weight_decay=1e-3)

        self.cer[metric] = cer_improvement
        reload_model(self.model)
        return None

    def plot(self):
        """Label the current matplotlib figure and save it to ``bleu_deltas.png``."""
        plt.xlabel('% Corrected Sentences')
        plt.ylabel('Δ BLEU')
        # Add titles
        plt.title("BLEU Change for Metrics", loc='center', fontsize=12, fontweight=0)
        # Add legend
        plt.legend(loc='lower right', ncol=1)
        plt.savefig('bleu_deltas.png')
Пример #11
0
def documentUpload():
    """Store an uploaded document, translate each of its sentences, and persist the result.

    Reads the multipart ``file`` field and the ``document_name`` query argument.
    Redirects back to ``request.url`` when there is no file part, the filename is
    empty, or it sanitizes to nothing. Otherwise responds with an empty JSON object.
    That response is also returned when ``allowed_file`` rejects the upload.
    """
    if 'file' not in request.files:
        return redirect(request.url)
    file = request.files['file']
    # If the user does not select a file, the browser also
    # submits an empty part without a filename.
    if file.filename == '':
        return redirect(request.url)
    if not (file and allowed_file(file.filename)):
        return jsonify({})

    document_name = request.args.get("document_name")
    doc_id = uuid4()
    filename = secure_filename(file.filename)
    if not filename:
        # secure_filename may reduce a hostile name (e.g. "../..") to "", which
        # would make filepath the upload folder itself.
        return redirect(request.url)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    user = User.query.filter_by(username=get_jwt_identity()).first()
    dbDocument = DBDocument(id=doc_id, name=document_name, user=user, model=model_name)

    document = Document(str(doc_id), document_name, dict(), filepath)
    # NOTE(review): load_content receives the bare filename, not filepath — confirm this is intended.
    sentences = document.load_content(filename)
    sentences = list(filter(None, sentences))  # remove empty lines

    # Rewrite the saved upload with BPE joints ("@@ ") removed, one sentence per
    # line and no trailing newline; the keyphrase extractor below reads this file.
    # Joining keeps the last sentence (a per-line conditional write dropped it).
    with open(filepath, "w", encoding='utf-8') as f:
        f.write("\n".join(sentence.replace("@@ ", "") for sentence in sentences))

    extractor = DomainSpecificExtractor(source_file=filepath, src_lang=SRC_LANG, tgt_lang=TGT_LANG,
                                        train_source_file=f".data/wmt14/train.tok.clean.bpe.32000.{SRC_LANG}",
                                        train_vocab_file=f".data/vocab/train_vocab_{SRC_LANG}.pkl")
    keyphrases = extractor.extract_keyphrases(n_results=30)

    scorer = Scorer()

    print("Translating {} sentences".format(len(sentences)))

    beamSize = 3
    attLayer = -2
    for i, source in enumerate(sentences):
        translation, attn, translations = model.translate(source, beam_size=beamSize, attLayer=attLayer,
                                                          beam_length=0.6, beam_coverage=0.4)
        print("Translated {} of {}".format(i + 1, len(sentences)))

        # Only the top beamSize hypotheses are turned into the beam tree.
        beam = translationsToTree(translations[:beamSize])

        hypothesis = " ".join(translation)
        score = scorer.compute_scores(source, hypothesis, attn, keyphrases, "")
        score["order_id"] = i
        document.sentences.append(Sentence(i, source, hypothesis, attn, beam, score))

    print("Finished translation")

    document.keyphrases = [{"name": name, "occurrences": count, "active": False}
                           for (name, count) in keyphrases]
    db.session.add(dbDocument)
    db.session.commit()

    save_document(document, doc_id)

    return jsonify({})