Example #1
 def __init__(self, src_path, cfg_path, src_vocab_path, treelstm_vocab_path, cache_path=None,
              batch_size=64, max_tokens=80,
              part_index=0, part_num=1,
              load_data=True,
              truncate=None,
              limit_datapoints=None,
              limit_tree_depth=0):
     self._max_tokens = max_tokens
     self._src_path = src_path
     self._src_vocab_path = src_vocab_path
     self._cfg_path = cfg_path
     self._treelstm_vocab_path = treelstm_vocab_path
     self._src_vocab = Vocab(self._src_vocab_path)
     self._label_vocab = Vocab(self._treelstm_vocab_path)
     self._cache_path = cache_path
     self._truncate = truncate
     self._part_index = part_index
     self._part_num = part_num
     self._limit_datapoints = limit_datapoints
     self._limit_tree_depth = limit_tree_depth
     self._rand = np.random.RandomState(3)
     if load_data:
         train_data, valid_data = self._load_data()
     else:
         # avoid a NameError below when data loading is skipped
         train_data, valid_data = [], []
     self._n_train_samples = len(train_data)
     super(BilingualTreeDataLoader, self).__init__(train_data=train_data, valid_data=valid_data, batch_size=batch_size)
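
For reference, a minimal instantiation sketch of the constructor above; the paths are hypothetical placeholders and only arguments shown in the signature are used:

# Hypothetical paths, for illustration only
loader = BilingualTreeDataLoader(
    src_path="data/train.src",
    cfg_path="data/train.cfg",
    src_vocab_path="data/src_vocab.txt",
    treelstm_vocab_path="data/cfg_vocab.txt",
    batch_size=64,
    max_tokens=80)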
Example #2
 def __init__(self, cfg_path, treelstm_vocab_path, part_index=0, part_num=1, cache_path=None, limit_datapoints=None,
              limit_tree_depth=0):
     if cache_path is not None:
         self._cache_path = "{}.{}in{}".format(cache_path, part_index, part_num)
     else:
         self._cache_path = None
     self._cfg_path = cfg_path
     self._cfg_lines = None
     self._part_index = part_index
     self._part_num = part_num
     self._limit_datapoints = limit_datapoints
     self._limit_tree_depth = limit_tree_depth
     self._vocab = Vocab(treelstm_vocab_path, picklable=True)
     self._trees = []
Example #3
    #    OPTS.batchtokens = 2048
    dataset = MTDataset(src_corpus=train_src_corpus,
                        tgt_corpus=tgt_corpus,
                        src_vocab=src_vocab_path,
                        tgt_vocab=tgt_vocab_path,
                        batch_size=OPTS.batchtokens * gpu_num,
                        batch_type="token",
                        truncate=truncate_datapoints,
                        max_length=TRAINING_MAX_TOKENS,
                        n_valid_samples=n_valid_samples)
else:
    dataset = None

# Create the model
basic_options = dict(dataset=dataset,
                     src_vocab_size=Vocab(src_vocab_path).size(),
                     tgt_vocab_size=Vocab(tgt_vocab_path).size(),
                     hidden_size=OPTS.hiddensz,
                     embed_size=OPTS.embedsz,
                     n_att_heads=OPTS.heads,
                     shard_size=OPTS.shard,
                     seed=OPTS.seed)

lanmt_options = basic_options.copy()
lanmt_options.update(
    dict(encoder_layers=5,
         prior_layers=OPTS.priorl,
         q_layers=OPTS.priorl,
         decoder_layers=OPTS.decoderl,
         latent_dim=OPTS.latentdim,
         KL_budget=0. if OPTS.finetune else OPTS.klbudget,
Example #4
if OPTS.train:
    dataset = MTDataset(src_corpus=train_src_corpus,
                        tgt_corpus=tgt_corpus,
                        src_vocab=src_vocab_path,
                        tgt_vocab=tgt_vocab_path,
                        batch_size=OPTS.batchtokens * gpu_num,
                        batch_type="token",
                        truncate=truncate_datapoints,
                        max_length=TRAINING_MAX_TOKENS,
                        n_valid_samples=500)
else:
    dataset = None

# Create the model
basic_options = dict(dataset=dataset,
                     src_vocab_size=Vocab(src_vocab_path).size(),
                     tgt_vocab_size=Vocab(tgt_vocab_path).size(),
                     hidden_size=OPTS.hiddensz,
                     embed_size=OPTS.embedsz,
                     n_att_heads=OPTS.heads,
                     shard_size=OPTS.shard,
                     seed=OPTS.seed)

nmt = EnergyLanguageModel(latent_size=OPTS.latentdim)

# Training
if OPTS.train or OPTS.all:
    # Training code
    scheduler = SimpleScheduler(max_epoch=20)
    # scheduler = TransformerScheduler(warm_steps=training_warmsteps, max_steps=training_maxsteps)
    lr = 0.0001 * gpu_num / 8
Example #5
    OPTS.batchtokens = 6144
    dataset = MTDataset(src_corpus=train_src_corpus,
                        tgt_corpus=tgt_corpus,
                        src_vocab=src_vocab_path,
                        tgt_vocab=tgt_vocab_path,
                        batch_size=OPTS.batchtokens * gpu_num,
                        batch_type="token",
                        truncate=truncate_datapoints,
                        max_length=TRAINING_MAX_TOKENS,
                        n_valid_samples=n_valid_samples)
else:
    dataset = None

# Create the model
basic_options = dict(dataset=dataset,
                     src_vocab_size=Vocab(src_vocab_path).size(),
                     tgt_vocab_size=Vocab(tgt_vocab_path).size(),
                     hidden_size=OPTS.hiddensz,
                     embed_size=OPTS.embedsz,
                     n_att_heads=OPTS.heads,
                     shard_size=OPTS.shard,
                     seed=OPTS.seed)

lanmt_options = basic_options.copy()
lanmt_options.update(
    dict(prior_layers=OPTS.priorl,
         decoder_layers=OPTS.decoderl,
         latent_dim=OPTS.latentdim,
         KL_budget=0. if OPTS.finetune else OPTS.klbudget,
         budget_annealing=OPTS.annealbudget,
         max_train_steps=training_maxsteps,
Example #6
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
sys.path.append("..")

from argparse import ArgumentParser
from nmtlab.utils import Vocab

if __name__ == '__main__':
    ap = ArgumentParser()
    ap.add_argument("-c", "--corpus", help="corpus path")
    ap.add_argument("-v", "--vocab", help="output path of the vocabulary")
    ap.add_argument("-l",
                    "--limit",
                    type=int,
                    default=0,
                    help="limit of maximum number of vocabulary items")
    args = ap.parse_args()

    vocab = Vocab()
    vocab.build(args.corpus, limit=args.limit)
    vocab.save(args.vocab)
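
The same vocabulary can also be built programmatically with the Vocab API used in this script; a minimal sketch, assuming hypothetical corpus and output paths:

# Hypothetical paths; mirrors the command-line flow above
vocab = Vocab()
vocab.build("data/train.tok.en", limit=40000)
vocab.save("data/vocab.en.txt")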
Example #7
                                pretrained_autoregressive_path)
 model_path = OPTS.model_path
 if not os.path.exists(model_path):
     print("Cannot find model in {}".format(model_path))
     sys.exit()
 # model_path = "{}/basemodel_wmt14_ende_x5longertrain_v2.pt.bak".format(DATA_ROOT)
 nmt.load(model_path)
 if torch.cuda.is_available():
     nmt.cuda()
 nmt.train(False)
 from nmtlab.decoding import BeamTranslator
 translator = BeamTranslator(nmt,
                             dataset.src_vocab(),
                             dataset.tgt_vocab(),
                             beam_size=OPTS.Tbeam)
 src_vocab = Vocab(src_vocab_path)
 tgt_vocab = Vocab(tgt_vocab_path)
 result_path = OPTS.result_path
 # Read data
 lines = open(test_src_corpus).readlines()
 latent_candidate_num = OPTS.Tcandidate_num if OPTS.Tlatent_search else None
 decode_times = []
 if OPTS.profile:
     lines = lines * 10
 # lines = lines[:100]
 # trains_stop_stdout_monitor()
 with open(OPTS.result_path, "w") as outf:
     for i, line in enumerate(lines):
         # Make a batch
         tokens = src_vocab.encode("<s> {} </s>".format(
             line.strip()).split())
Example #8
class TreeDataGenerator(object):

    def __init__(self, cfg_path, treelstm_vocab_path, part_index=0, part_num=1, cache_path=None, limit_datapoints=None,
                 limit_tree_depth=0):
        if cache_path is not None:
            self._cache_path = "{}.{}in{}".format(cache_path, part_index, part_num)
        else:
            self._cache_path = None
        self._cfg_path = cfg_path
        self._cfg_lines = None
        self._part_index = part_index
        self._part_num = part_num
        self._limit_datapoints = limit_datapoints
        self._limit_tree_depth = limit_tree_depth
        self._vocab = Vocab(treelstm_vocab_path, picklable=True)
        self._trees = []

    def load(self):
        if not OPTS.smalldata and not OPTS.tinydata and self._cache_path is not None and os.path.exists(self._cache_path):
            print("loading cached trees part {} ...".format(self._part_index))
            self._trees = pickle.load(open(self._cache_path, "rb"))
            return
        self._cfg_lines = open(self._cfg_path).readlines()
        partition_size = int(len(self._cfg_lines) / self._part_num)
        self._cfg_lines = self._cfg_lines[self._part_index * partition_size: (self._part_index + 1) * partition_size]
        if self._limit_datapoints is not None and self._limit_datapoints > 0:
            self._cfg_lines = self._cfg_lines[:self._limit_datapoints]
        print("building trees part {} ...".format(self._part_index))
        self._trees = self._build_batch_trees()
        if False and self._cache_path is not None:  # caching of built trees is intentionally disabled here
            print("caching trees part {}...".format(self._part_index))
            pickle.dump(self._trees, open(self._cache_path, "wb"))

    def _parse_cfg_line(self, cfg_line):
        t = cfg_line.strip()
        # Replace leaves of the form (!), (,), with (! !), (, ,)
        t = re.sub(r"\((.)\)", r"(\1 \1)", t)
        # Replace leaves of the form (tag word root) with (tag word)
        t = re.sub(r"\(([^\s()]+) ([^\s()]+) [^\s()]+\)", r"(\1 \2)", t)
        try:
            tree = Tree.fromstring(t)
        except ValueError:
            tree = None
        return tree

    def _build_batch_trees(self):
        trees = []
        for line in self._cfg_lines:
            paired_tree = self.build_trees(line)
            trees.append(paired_tree)
        return trees

    def build_trees(self, cfg_line):
        parse = self._parse_cfg_line(cfg_line)
        if parse is None or not parse.leaves():
            return None
        enc_g = nx.DiGraph()
        dec_g = nx.DiGraph()
        failed = False

        def _rec_build(id_enc, id_dec, node, depth=0):
            if len(node) > 10:
                return
            if self._limit_tree_depth > 0 and depth >= self._limit_tree_depth:
                return
            # Skip subtrees whose children are all terminals
            all_terminals = True
            for child in node:
                if not isinstance(child[0], str) and not isinstance(child[0], bytes):
                    all_terminals = False
                    break
            if all_terminals:
                return
            for j, child in enumerate(node):
                cid_enc = enc_g.number_of_nodes()
                cid_dec = dec_g.number_of_nodes()

                # Avoid leaf nodes
                tagid_enc = self._vocab.encode_token("{}_1".format(child.label()))
                tagid_dec = self._vocab.encode_token("{}_{}".format(node.label(), j+1))
                # assert tagid_enc != UNK and tagid_dec != UNK
                enc_g.add_node(cid_enc, x=tagid_enc, mask=0)
                dec_g.add_node(cid_dec, x=tagid_dec, y=tagid_enc, pos=j, mask=0, depth=depth+1)
                enc_g.add_edge(cid_enc, id_enc)
                dec_g.add_edge(id_dec, cid_dec)
                if not isinstance(child[0], str) and not isinstance(child[0], bytes):
                    _rec_build(cid_enc, cid_dec, child, depth=depth + 1)

        if parse.label() == "ROOT" and len(parse) == 1:
            # Skip the root node
            parse = parse[0]
        root_tagid = self._vocab.encode_token("{}_1".format(parse.label()))
        enc_g.add_node(0, x=root_tagid, mask=1)
        dec_g.add_node(0, x=self._vocab.encode_token("ROOT_1"), y=root_tagid, pos=0, mask=1, depth=0)
        _rec_build(0, 0, parse)
        if failed:
            return None
        enc_graph = dgl.DGLGraph()
        enc_graph.from_networkx(enc_g, node_attrs=['x', 'mask'])
        dec_graph = dgl.DGLGraph()
        dec_graph.from_networkx(dec_g, node_attrs=['x', 'y', 'pos', 'mask', 'depth'])
        return enc_graph, dec_graph

    def trees(self):
        return self._trees
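
A minimal usage sketch for TreeDataGenerator, assuming hypothetical CFG corpus and vocabulary paths; build_trees() can return None for unparsable lines, so those entries are skipped:

# Hypothetical paths, for illustration only
treegen = TreeDataGenerator("data/train.cfg", "data/cfg_vocab.txt",
                            part_index=0, part_num=1)
treegen.load()
for pair in treegen.trees():
    if pair is None:
        continue  # line could not be parsed into a tree
    enc_graph, dec_graph = pair  # DGL graphs for the encoder and decoder trees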
Example #9
if OPTS.train or OPTS.all:
    dataset = MTDataset(src_corpus=train_src_corpus,
                        tgt_corpus=tgt_corpus,
                        src_vocab=src_vocab_path,
                        tgt_vocab=tgt_vocab_path,
                        batch_size=OPTS.batchtokens * gpu_num,
                        batch_type="token",
                        truncate=truncate_datapoints,
                        max_length=TRAINING_MAX_TOKENS,
                        n_valid_samples=500)
else:
    dataset = None

# Create the model
basic_options = dict(dataset=dataset,
                     src_vocab_size=Vocab(src_vocab_path).size(),
                     tgt_vocab_size=Vocab(tgt_vocab_path).size(),
                     hidden_size=OPTS.hiddensz,
                     embed_size=OPTS.embedsz,
                     n_att_heads=OPTS.heads,
                     shard_size=OPTS.shard,
                     seed=OPTS.seed)

nmt = IndependentEnergyMT(latent_size=OPTS.latentdim)

# Training
if OPTS.train or OPTS.all:
    # Training code
    scheduler = SimpleScheduler(max_epoch=OPTS.epochs)
    # scheduler = TransformerScheduler(warm_steps=training_warmsteps, max_steps=training_maxsteps)
    lr = 0.0001 * gpu_num / 8
Example #10
                key = "{}\t{}".format(src.strip(), cfg.strip())
                if key in export_map:
                    outf.write("{} <eoc> {}\n".format(export_map[key],
                                                      tgt.strip()))
                else:
                    outf.write("\n")

if OPTS.make_oracle_codes:
    if is_root_node():
        from nmtlab.utils import Vocab
        from lib_treedata import TreeDataGenerator
        import torch

        treegen = TreeDataGenerator(dataset_paths["test_cfg_corpus"],
                                    dataset_paths["cfg_vocab_path"])
        src_vocab = Vocab(dataset_paths["src_vocab_path"])
        samples = list(
            zip(open(dataset_paths["test_src_corpus"]),
                open(dataset_paths["test_cfg_corpus"])))

        print("loading", OPTS.model_path)
        assert os.path.exists(OPTS.model_path)
        autoencoder.load(OPTS.model_path)
        out_path = "{}/{}.test.export".format(
            DATA_ROOT,
            os.path.basename(OPTS.model_path).split(".")[0])
        autoencoder.train(False)
        if torch.cuda.is_available():
            autoencoder.cuda()
        with open(out_path, "w") as outf:
            print("code path", out_path)