def validate(model,
             dev_data,
             vocab_src,
             vocab_tgt,
             epoch,
             config,
             direction=None):
    """Beam-search decode the development data and return its BLEU score.

    Args:
        model: translation model exposing ``encode``, ``init_decoder``,
            ``decoder``, ``emb_tgt`` and ``generate_tm``.
        dev_data: dataset of parallel (x, y) sentence pairs.
        vocab_src: vocabulary used to encode the source side (used for
            both directions).
        vocab_tgt: vocabulary used for decoding and special tokens.
        epoch: epoch number forwarded to ``save_hypotheses`` and
            ``compute_bleu``.
        config: dict; this function reads "device" and "batch_size_eval".
        direction: ``None`` or ``"xy"`` translates x -> y; any other value
            translates y -> x (the x side then serves as reference).

    Returns:
        Whatever ``compute_bleu`` returns for the cleaned hypotheses.
    """
    model.eval()
    if config["device"] == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:0")

    # Resolve the translation direction once instead of on every batch.
    translate_xy = direction is None or direction == "xy"

    with torch.no_grad():
        hyps, refs = [], []

        loader = BucketingParallelDataLoader(
            DataLoader(dev_data,
                       batch_size=config["batch_size_eval"],
                       shuffle=False,
                       num_workers=4))

        for batch_x, batch_y in loader:
            if translate_xy:
                src_batch, ref_batch = batch_x, batch_y
            else:
                src_batch, ref_batch = batch_y, batch_x

            x_in, _, x_mask, x_len = create_batch(src_batch, vocab_src,
                                                  device)
            x_mask = x_mask.unsqueeze(1)

            enc_output, enc_hidden = model.encode(x_in, x_len)
            dec_hidden = model.init_decoder(enc_output, enc_hidden)

            raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                         model.generate_tm, enc_output,
                                         dec_hidden, x_mask, vocab_tgt.size(),
                                         vocab_tgt[SOS_TOKEN],
                                         vocab_tgt[EOS_TOKEN],
                                         vocab_tgt[PAD_TOKEN], config)

            hyps.extend(batch_to_sentences(raw_hypothesis, vocab_tgt).tolist())
            refs.extend(ref_batch.tolist())

        save_hypotheses(hyps, epoch, config)
        hyps, refs = clean_sentences(hyps, refs, config)
        return compute_bleu(hyps, refs, epoch, config, direction)
    (
        'https://master.dl.sourceforge.net/project/autshumato/Corpora/ENG-NSO.Release.zip',  # url
        'lcontent.DACB.DataVirVrystellingOpWeb.(eng-nso).nso.1.0.0.CAM.2010-09-23.txt',  # file_name
        'sepedi/sepedi.txt',  # output_name
        5301,  # lines_to_remove (constitution with poor formatting)
    )
]

for url, file_name, output_name, lines_to_remove in datasets:
    print('processing:', url)
    r = requests.get(url)
    # Fail here on HTTP errors; otherwise an HTML error page would be handed
    # to zipfile and surface later as an unhelpful BadZipFile.
    r.raise_for_status()

    # Context managers close both the archive and the member stream; the name
    # `archive` also avoids shadowing the built-in zip().
    with zipfile.ZipFile(BytesIO(r.content)) as archive:
        with archive.open(file_name) as member:
            corpus = member.read().decode('utf-8')

    sentences = corpus.splitlines()

    sentences = utils.clean_sentences(sentences, min_length=3, lines_to_remove=lines_to_remove)

    output_file_name = os.path.join(args.output_dir, output_name)
    # output_name may include a subdirectory (e.g. 'sepedi/sepedi.txt'),
    # which must exist before open() can create the file.
    output_subdir = os.path.dirname(output_file_name)
    if output_subdir:
        os.makedirs(output_subdir, exist_ok=True)
    with open(output_file_name, 'w', encoding='utf-8') as f:
        f.write('\n'.join(sentences))

    print('total sentences in {}:'.format(output_name), len(sentences))

print('Autshumato datasets provided under Creative Commons Attribution Non-Commercial ShareAlike, '
      'CTexT (Centre for Text Technology, North-West University), South Africa; '
      'Department of Arts and Culture, South Africa. '
      'http://autshumato.sourceforge.net/ and http://www.nwu.ac.za/ctext')
Esempio n. 3
0
            article = re.sub(r'\(\d*\)', '', article)

            # remove extra whitespace
            article = re.sub('\\s+', ' ', article)

            # replace strange quote character
            article = utils.normalize_quote_characters(article)

            # discard articles with imbalanced quotes
            if article.count('"') % 2 != 0:
                continue

            # split article into array of sentences
            # regex help from https://stackoverflow.com/questions/11502598/how-to-match-something-with-regex-that-is-not-between-two-special-characters
            sentences = re.split('(?<=\.|\!|\?) (?=(?:[^"]*"[^"]*")*[^"]*\Z)',
                                 article)

            corpus = corpus + sentences

corpus = utils.clean_sentences(corpus,
                               min_length=16,
                               illegal_substrings=['@', '%2'])

print('total sentences:', len(corpus))

# Build the path with os.path.join rather than concatenating separators by
# hand, and create the directory so open() does not fail with
# FileNotFoundError on a fresh output_dir.
isizulu_dir = os.path.join(args.output_dir, 'isizulu')
os.makedirs(isizulu_dir, exist_ok=True)
with open(os.path.join(isizulu_dir, 'isizulu.txt'), 'w',
          encoding='utf-8') as f:
    # One sentence per line, with a trailing newline after the last one.
    f.write('\n'.join(corpus) + '\n')
def validate(model,
             dev_data,
             vocab_src,
             vocab_tgt,
             epoch,
             config,
             direction=None):
    """Decode the dev set with a latent-variable model and report BLEU + KL.

    For every batch the approximate posterior ``q(z|x)`` is computed, its
    mean is used as ``z`` for encoding/decoding, and the KL divergence to the
    prior ``Normal(model.prior_loc, model.prior_scale)`` is accumulated.

    Args:
        model: model exposing ``inference``, ``encode``, ``init_decoder``,
            ``decoder``, ``emb_tgt``, ``generate_tm``, ``prior_loc`` and
            ``prior_scale``.
        dev_data: dataset of parallel (x, y) sentence pairs.
        vocab_src: vocabulary used to encode the source side (both directions).
        vocab_tgt: vocabulary used for decoding and special tokens.
        epoch: epoch number forwarded to ``save_hypotheses``/``compute_bleu``.
        config: dict; this function reads "device" and "batch_size_eval".
        direction: ``None`` or ``"xy"`` translates x -> y; anything else
            translates y -> x.

    Returns:
        Whatever ``compute_bleu`` returns; the per-example average KL
        (summed KL divided by ``len(dev_data)``) is passed to it as ``kl``.
    """
    model.eval()
    if config["device"] == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:0")

    translate_xy = direction is None or direction == "xy"

    with torch.no_grad():
        hyps, refs = [], []

        loader = BucketingParallelDataLoader(
            DataLoader(dev_data,
                       batch_size=config["batch_size_eval"],
                       shuffle=False,
                       num_workers=2))
        total_kl = 0
        for batch_x, batch_y in loader:
            if translate_xy:
                src_batch, ref_batch = batch_x, batch_y
            else:
                src_batch, ref_batch = batch_y, batch_x

            x_in, _, x_mask, x_len = create_batch(src_batch, vocab_src,
                                                  device)
            x_mask = x_mask.unsqueeze(1)

            posterior = model.inference(x_in, x_mask, x_len)
            z = posterior.mean

            prior = torch.distributions.Normal(
                loc=model.prior_loc,
                scale=model.prior_scale).expand(posterior.mean.size())
            # Sum over latent dimensions, then over the batch.
            batch_kl = torch.distributions.kl.kl_divergence(posterior, prior)
            total_kl += batch_kl.sum(dim=1).sum(dim=0)

            enc_output, enc_hidden = model.encode(x_in, x_len, z)
            dec_hidden = model.init_decoder(enc_output, enc_hidden, z)

            raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                         model.generate_tm, enc_output,
                                         dec_hidden, x_mask, vocab_tgt.size(),
                                         vocab_tgt[SOS_TOKEN],
                                         vocab_tgt[EOS_TOKEN],
                                         vocab_tgt[PAD_TOKEN], config, z)

            hyps.extend(batch_to_sentences(raw_hypothesis, vocab_tgt).tolist())
            refs.extend(ref_batch.tolist())

        total_kl /= len(dev_data)
        save_hypotheses(hyps, epoch, config, direction)
        hyps, refs = clean_sentences(hyps, refs, config)
        return compute_bleu(hyps,
                            refs,
                            epoch,
                            config,
                            direction,
                            kl=total_kl)
def main():
    """Score a fixed checkpoint on the "comparable" dev split and print BLEU.

    Loads config and vocabularies, restores the weights stored at
    ``{out_dir}/cond_nmt_de-en_run_7/checkpoints/cond_nmt_de-en_run_7``,
    beam-searches every dev sentence (x -> y) and prints the sacrebleu
    corpus score of the cleaned hypotheses against the y side.
    """
    config = setup_config()
    config["dev_prefix"] = "comparable"
    vocab_src, vocab_tgt = load_vocabularies(config)
    _, dev_data, _ = load_data(config,
                               vocab_src=vocab_src,
                               vocab_tgt=vocab_tgt)

    # Resolve the device once and use it for the model, the checkpoint and
    # the input batches, so they cannot end up on different devices.
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")

    model, _, _ = create_model(vocab_src, vocab_tgt, config)
    model.to(device)

    checkpoint_path = "{}/cond_nmt_de-en_run_7/checkpoints/cond_nmt_de-en_run_7".format(
        config["out_dir"])

    # map_location allows a checkpoint saved on GPU to be restored on a
    # CPU-only machine (and vice versa).
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['state_dict'])

    model.eval()
    with torch.no_grad():
        model_hypotheses = []
        references = []

        val_dl = DataLoader(dev_data,
                            batch_size=config["batch_size_eval"],
                            shuffle=False,
                            num_workers=4)
        for sentences_x, sentences_y in tqdm(val_dl):

            # Sort the sources longest-first for batching; sentences_y stays
            # in loader order and hypotheses are un-sorted below to match.
            sentences_x = np.array(sentences_x)
            seq_len = np.array([len(s.split()) for s in sentences_x])
            sort_keys = np.argsort(-seq_len)
            sentences_x = sentences_x[sort_keys]
            sentences_y = np.array(sentences_y)

            x_in, _, x_mask, x_len = create_batch(sentences_x, vocab_src,
                                                  device)
            x_mask = x_mask.unsqueeze(1)

            if config["model_type"] == "aevnmt":
                qz = model.inference(x_in, x_mask, x_len)
                z = qz.mean

                enc_output, enc_hidden = model.encode(x_in, x_len, z)
                dec_hidden = model.init_decoder(enc_output, enc_hidden, z)

                # NOTE(review): z is not passed to beam_search here, unlike
                # the latent-variable validate/evaluate paths; confirm
                # whether this beam_search version accepts z.
                raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                             model.generate_tm, enc_output,
                                             dec_hidden, x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN], config)
            else:
                enc_output, enc_hidden = model.encode(x_in, x_len)
                dec_hidden = model.decoder.initialize(enc_output, enc_hidden)

                raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                             model.generate_tm, enc_output,
                                             dec_hidden, x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN], config)

            hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)

            # Restore the loader's original order so hypotheses line up with
            # the (unsorted) references.
            inverse_sort_keys = np.argsort(sort_keys)
            model_hypotheses += hypothesis[inverse_sort_keys].tolist()

            references += sentences_y.tolist()
        save_hypotheses(model_hypotheses, 0, config, None)
        model_hypotheses, references = clean_sentences(model_hypotheses,
                                                       references, config)
        bleu = sacrebleu.raw_corpus_bleu(model_hypotheses, [references]).score
        print(bleu)
Esempio n. 6
0
]

for url, file_name, output_name, lines_to_remove in datasets:
    print('processing:', url)

    r = requests.get(url)
    # Fail here on HTTP errors instead of feeding an HTML error page to
    # zipfile (which would raise an unhelpful BadZipFile).
    r.raise_for_status()
    # Context managers close the archive and member; `archive` also avoids
    # shadowing the built-in zip().
    with zipfile.ZipFile(BytesIO(r.content)) as archive:
        with archive.open(file_name) as member:
            corpus = member.read().decode('utf-8').strip()

    # remove tags containing article filenames; the non-greedy match keeps
    # text that sits between two tags on the same line
    corpus = re.sub(r'<fn>.*?</fn>', '', corpus)

    # put each sentence on a new line
    corpus = corpus.replace('. ', '.\n')

    # split into one entry per line (filtering, presumably including empty
    # lines, is left to utils.clean_sentences)
    sentences = corpus.splitlines()

    sentences = utils.clean_sentences(
        sentences,
        illegal_substrings=['\ufeff', '='],
        lines_to_remove=lines_to_remove,
    )

    # write article to file (with each sentence on a new line)
    output_file_name = os.path.join(args.output_dir, output_name)
    output_subdir = os.path.dirname(output_file_name)
    if output_subdir:
        os.makedirs(output_subdir, exist_ok=True)
    with open(output_file_name, 'w', encoding='utf-8') as f:
        f.write('\n'.join(sentences))

    # Report the number of sentences actually written, not the newline count
    # of the raw, uncleaned corpus.
    print('total sentences in {}:'.format(output_name), len(sentences))
Esempio n. 7
0
def evaluate(model, dev_data, vocab_src, vocab_tgt, config, direction=None):
    """Beam-search decode ``dev_data``, print its corpus BLEU and return it.

    Each batch is length-sorted with ``sort_sentences`` for decoding, and the
    hypotheses are restored to the loader's original order afterwards.
    References are captured in that same original order *before* sorting,
    so hypothesis i is always scored against its own reference.

    Args:
        model: model exposing ``encode``, ``decoder``, ``emb_tgt``,
            ``generate_tm`` and, for "coaevnmt", ``inference`` and
            ``init_decoder``.
        dev_data: dataset of parallel (x, y) sentence pairs.
        vocab_src: vocabulary used to encode the source side (both directions).
        vocab_tgt: vocabulary used for decoding and special tokens.
        config: dict; reads "device", "batch_size_eval" and "model_type".
        direction: ``None`` or ``"xy"`` translates x -> y; any other value
            translates y -> x (the x side then serves as reference).

    Returns:
        The sacrebleu corpus BLEU score (also printed).

    Raises:
        ValueError: if ``config["model_type"]`` is neither "coaevnmt" nor
            "conmt" (previously this surfaced as an UnboundLocalError on
            ``raw_hypothesis``).
    """
    model_type = config["model_type"]
    if model_type not in ("coaevnmt", "conmt"):
        raise ValueError(
            "evaluate: unsupported model_type {!r}".format(model_type))

    translate_xy = direction is None or direction == "xy"

    model.eval()
    with torch.no_grad():
        model_hypotheses = []
        references = []

        device = torch.device(
            "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
        val_dl = DataLoader(dev_data,
                            batch_size=config["batch_size_eval"],
                            shuffle=False,
                            num_workers=4)
        for sentences_x, sentences_y in tqdm(val_dl):
            # Take the references in loader order before sorting: hypotheses
            # are un-sorted back to this order below, so the references must
            # not be in sorted order.
            if translate_xy:
                batch_references = np.asarray(sentences_y).tolist()
                sentences_x, sentences_y, sort_keys = sort_sentences(
                    sentences_x, sentences_y)
                source_sentences = sentences_x
            else:
                batch_references = np.asarray(sentences_x).tolist()
                sentences_y, sentences_x, sort_keys = sort_sentences(
                    sentences_y, sentences_x)
                source_sentences = sentences_y

            x_in, _, x_mask, x_len = create_batch(source_sentences, vocab_src,
                                                  device)
            x_mask = x_mask.unsqueeze(1)

            if model_type == "coaevnmt":
                # Decode with the posterior mean as the latent code.
                qz = model.inference(x_in, x_mask, x_len)
                z = qz.mean

                enc_output, enc_hidden = model.encode(x_in, x_len, z)
                dec_hidden = model.init_decoder(enc_output, enc_hidden, z)

                raw_hypothesis = beam_search(model.decoder,
                                             model.emb_tgt,
                                             model.generate_tm,
                                             enc_output,
                                             dec_hidden,
                                             x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN],
                                             config,
                                             z=z)
            else:  # "conmt"
                enc_output, enc_hidden = model.encode(x_in, x_len)
                dec_hidden = model.decoder.initialize(enc_output, enc_hidden)

                raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                             model.generate_tm, enc_output,
                                             dec_hidden, x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN], config)

            hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
            # Undo the length sort so hypotheses follow loader order.
            inverse_sort_keys = np.argsort(sort_keys)
            model_hypotheses += hypothesis[inverse_sort_keys].tolist()
            references += batch_references

        model_hypotheses, references = clean_sentences(model_hypotheses,
                                                       references, config)
        bleu = sacrebleu.raw_corpus_bleu(model_hypotheses, [references]).score
        print(bleu)
        return bleu