Example #1
import onmt.io

# `parse_args`, `get_logger`, `build_save_dataset` and `build_save_vocab`
# are helpers defined elsewhere in the same preprocessing script.


def main():
    opt = parse_args()
    logger = get_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = onmt.io.get_num_features(opt.data_type, opt.train_src, 'src')
    conversation_nfeats = onmt.io.get_num_features(opt.data_type,
                                                   opt.train_conv,
                                                   'conversation')
    tgt_nfeats = onmt.io.get_num_features(opt.data_type, opt.train_tgt, 'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of conversation features: %d." %
                conversation_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = onmt.io.get_fields(opt.data_type, src_nfeats, conversation_nfeats,
                                tgt_nfeats)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt, logger)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt, logger)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt, logger)
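
# For context: get_num_features reports how many extra annotations each
# token in the data file carries. Below is a minimal illustrative counter
# (NOT the real onmt.io.get_num_features); it assumes OpenNMT's
# word￨feat1￨feat2 annotation convention with the "￨" separator.
FEAT_SEP = u"￨"


def count_token_features(path):
    # Count the feature separators on the first token of the file.
    with open(path, encoding="utf-8") as f:
        first_token = f.readline().split()[0]
    return first_token.count(FEAT_SEP)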
Example #2
import torch

# Fragment: the tail of `main()` from an embeddings-to-torch script. The
# call truncated here produced `filtered_enc_embeddings`/`enc_count` and
# `filtered_dec_embeddings`/`dec_count`; `get_logger` and the rest of
# `main` come from the same script.
    logger.info("\nMatching: ")
    match_percent = [c['match'] / (c['match'] + c['miss']) * 100
                     for c in [enc_count, dec_count]]
    logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
                % (enc_count['match'],
                   enc_count['miss'],
                   match_percent[0]))
    logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
                % (dec_count['match'],
                   dec_count['miss'],
                   match_percent[1]))

    logger.info("\nFiltered embeddings:")
    logger.info("\t* enc: ", filtered_enc_embeddings.size())
    logger.info("\t* dec: ", filtered_dec_embeddings.size())

    enc_output_file = opt.output_file + ".enc.pt"
    dec_output_file = opt.output_file + ".dec.pt"
    logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s"
                % (enc_output_file, dec_output_file))
    torch.save(filtered_enc_embeddings, enc_output_file)
    torch.save(filtered_dec_embeddings, dec_output_file)
    logger.info("\nDone.")
    """


if __name__ == "__main__":
    logger = get_logger('embeddings_to_torch.log')
    main()
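
    # The saved files are plain tensors, so they round-trip with
    # torch.load. The filename below assumes opt.output_file was
    # "embeddings" (hypothetical -- adjust to your run).
    enc_emb = torch.load("embeddings.enc.pt")
    print(enc_emb.size())  # e.g. vocab_size x embedding_dim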
Example #3
from __future__ import division, unicode_literals
import argparse

from onmt.translate.Translator import make_translator
from onmt.Utils import get_logger

import onmt.io
import onmt.translate
import onmt
import onmt.ModelConstructor
import onmt.modules
import onmt.opts


def main(opt):
    translator = make_translator(opt, report_score=True, logger=logger)
    translator.translate(opt.src_dir, opt.src, opt.tgt,
                         opt.batch_size, opt.attn_debug)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='translate.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(parser)
    onmt.opts.translate_opts(parser)

    opt = parser.parse_args()
    logger = get_logger(opt.log_file)
    main(opt)
Example #4
import argparse
import os
import shutil
import sys

# Fragment: the snippet opens with the cleanup tail of a `test_rouge`
# helper (`test_rouge` itself and `get_logger` are defined elsewhere in
# the same script).
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)


def rouge_results_to_str(results_dict):
    return ">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}".format(
        results_dict["rouge_1_f_score"] * 100,
        results_dict["rouge_2_f_score"] * 100,
        results_dict["rouge_3_f_score"] * 100,
        results_dict["rouge_l_f_score"] * 100,
        results_dict["rouge_su*_f_score"] * 100)


if __name__ == "__main__":
    logger = get_logger('test_rouge.log')
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        type=str,
                        default="candidate.txt",
                        help='candidate file')
    parser.add_argument('-r',
                        type=str,
                        default="reference.txt",
                        help='reference file')
    args = parser.parse_args()
    if args.c.upper() == "STDIN":
        args.c = sys.stdin
    results_dict = test_rouge(args.c, args.r)
    logger.info(rouge_results_to_str(results_dict))
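    # Usage note: passing "-c STDIN" makes the script read the candidate
    # file from standard input, e.g.:
    #   cat candidate.txt | python test_rouge.py -c STDIN -r reference.txt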
Example #5
import random

import onmt.io

# `parse_args`, `get_logger`, `DatasetParser`, `build_save_dataset` and
# `build_save_vocab` are helpers defined elsewhere in the same script.


def main():
    opt = parse_args()
    logger = get_logger(opt.log_file)
    logger.info("Extracting features...")

    random.seed(13)

    # Load the training data.
    dataset = opt.dataset.lower()
    if dataset == 'e2e':
        parser = DatasetParser('data/e2e/trainset.csv', 'data/e2e/devset.csv',
                               'data/e2e/testset_w_refs.csv', 'E2E', opt)
    elif dataset == 'webnlg':
        parser = DatasetParser('data/webNLG_challenge_data/train',
                               'data/webNLG_challenge_data/dev', False,
                               'webNLG', opt)
    elif dataset == 'sfhotel':
        parser = DatasetParser('data/sfx_data/sfxhotel/train.json',
                               'data/sfx_data/sfxhotel/valid.json',
                               'data/sfx_data/sfxhotel/test.json', 'SFHotel',
                               opt)

    if opt.name.lower() != 'def':
        parser.dataset_name = opt.name.lower()
    for predicate in parser.trainingInstances:
        random.shuffle(parser.trainingInstances[predicate])

    gen_templ = parser.get_onmt_file_templ(opt)
    (train_src_templ, train_tgt_templ, train_eval_refs_templ,
     valid_src_templ, valid_tgt_templ, valid_eval_refs_templ,
     test_src_templ, test_tgt_templ,
     test_eval_refs_templ) = parser.get_onmt_file_templs(gen_templ)
    for predicate in parser.predicates:
        opt.file_templ = gen_templ.format(predicate)
        opt.train_src = train_src_templ.format(predicate)
        opt.train_tgt = train_tgt_templ.format(predicate)
        opt.valid_src = valid_src_templ.format(predicate)
        opt.valid_tgt = valid_tgt_templ.format(predicate)

        src_nfeats = onmt.io.get_num_features(opt.data_type, opt.train_src,
                                              'src')
        tgt_nfeats = onmt.io.get_num_features(opt.data_type, opt.train_tgt,
                                              'tgt')
        logger.info(" * number of source features: %d." % src_nfeats)
        logger.info(" * number of target features: %d." % tgt_nfeats)

        logger.info("Building {} `Fields` object...".format(predicate))
        fields = onmt.io.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

        logger.info("Building & saving {} training data...".format(predicate))
        train_dataset_files = build_save_dataset('train', fields, opt, logger)

        logger.info("Building & saving {} vocabulary...".format(predicate))
        build_save_vocab(train_dataset_files, fields, opt, logger)

        logger.info(
            "Building & saving {} validation data...".format(predicate))
        build_save_dataset('valid', fields, opt, logger)

    if parser.trainingInstances:
        for predicate in parser.trainingInstances:
            print("Training data size for {}: {}".format(
                predicate, len(parser.trainingInstances[predicate])))
    if parser.developmentInstances:
        for predicate in parser.developmentInstances:
            print("Validation data size for {}: {}".format(
                predicate, len(parser.developmentInstances[predicate])))
    if parser.testingInstances:
        for predicate in parser.testingInstances:
            print("Test data size for {}: {}".format(
                predicate, len(parser.testingInstances[predicate])))
    print("-----------------------")
Example #6
import torch
import torch.nn as nn


class MatrixTree(nn.Module):
    """Matrix-Tree theorem marginals for non-projective dependency
    structures (used for structured attention). The class wrapper and
    constructor below are assumed -- only `forward` appeared in the
    original snippet, which references `self.eps`."""

    def __init__(self, eps=1e-5):
        self.eps = eps  # small smoothing constant (assumed default)
        super(MatrixTree, self).__init__()

    def forward(self, input):
        laplacian = input.exp() + self.eps
        output = input.clone()
        for b in range(input.size(0)):
            # Graph Laplacian of the exponentiated scores, with the
            # diagonal zeroed out first (requires a CUDA device).
            lap = laplacian[b].masked_fill(
                torch.eye(input.size(1)).cuda().ne(0), 0)
            lap = -lap + torch.diag(lap.sum(0))
            # Store the root scores on the diagonal (root handling of the
            # Matrix-Tree theorem).
            lap[0] = input[b].diag().exp()
            inv_laplacian = lap.inverse()

            factor = inv_laplacian.diag().unsqueeze(1)\
                                         .expand_as(input[b]).transpose(0, 1)
            term1 = input[b].exp().mul(factor).clone()
            term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone()
            term1[:, 0] = 0
            term2[0] = 0
            output[b] = term1 - term2
            roots_output = input[b].diag().exp().mul(
                inv_laplacian.transpose(0, 1)[0])
            output[b] = output[b] + torch.diag(roots_output)
        return output


if __name__ == "__main__":
    logger = get_logger('StructuredAttention.log')
    dtree = MatrixTree()
    q = torch.rand(1, 5, 5).cuda()
    marg = dtree.forward(q)
    logger.info(marg.sum(1))
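    # Sanity check (same CUDA assumption as above): for every token, the
    # marginal probabilities over candidate heads should sum to roughly 1,
    # so marg.sum(1) is expected to be close to all-ones.
    assert torch.allclose(marg.sum(1), torch.ones(1, 5).cuda(), atol=1e-3)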
Example #7
# Fragment from an embedding-extraction script: it starts mid-`main`,
# after a model `checkpoint` has been loaded and the `src_dict`/`tgt_dict`
# vocabularies built; `write_embeddings` and `get_logger` are defined in
# the same script.
    fields = onmt.io.load_fields_from_vocab(checkpoint['vocab'])

    model_opt = checkpoint['opt']
    for arg in dummy_opt.__dict__:
        if arg not in model_opt:
            model_opt.__dict__[arg] = dummy_opt.__dict__[arg]

    model = onmt.ModelConstructor.make_base_model(model_opt, fields,
                                                  use_gpu(opt), checkpoint)
    encoder = model.encoder
    decoder = model.decoder

    encoder_embeddings = encoder.embeddings.word_lut.weight.data.tolist()
    decoder_embeddings = decoder.embeddings.word_lut.weight.data.tolist()

    logger.info("Writing source embeddings")
    write_embeddings(opt.output_dir + "/src_embeddings.txt", src_dict,
                     encoder_embeddings)

    logger.info("Writing target embeddings")
    write_embeddings(opt.output_dir + "/tgt_embeddings.txt", tgt_dict,
                     decoder_embeddings)

    logger.info('... done.')
    logger.info('Converting model...')


if __name__ == "__main__":
    logger = get_logger('extract_embeddings.log')
    main()