Example #1
from os import path

from onmt.bin.build_vocab import build_vocab_main
from onmt.opts import dynamic_prepare_opts
from onmt.utils.parse import ArgumentParser


def build_vocabulary(ds: Dataset):
    # Sample 10k lines from the corpora listed in the dataset's config
    # and build the source/target vocabularies from them.
    base_args = [
        "-config", path.join(ds.path, "config.yaml"), "-n_sample", "10000"
    ]

    parser = ArgumentParser(description='vocab.py')
    dynamic_prepare_opts(parser, build_vocab_only=True)

    options, unknown = parser.parse_known_args(base_args)
    build_vocab_main(options)

    return options, unknown
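
As a usage sketch (not part of the original): `build_vocabulary` only reads `ds.path`, so any object with a `path` attribute pointing at a directory containing a config.yaml will do; the `SimpleNamespace` stand-in below is hypothetical.

from types import SimpleNamespace

# Hypothetical stand-in for the Dataset argument; only .path is read.
ds = SimpleNamespace(path='data/my_dataset')
options, unknown = build_vocabulary(ds)
print(unknown)  # any flags dynamic_prepare_opts did not recognize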
Example #2
from onmt.opts import dynamic_prepare_opts
from onmt.utils.parse import ArgumentParser


def get_default_opts():
    parser = ArgumentParser(description='data sample prepare')
    dynamic_prepare_opts(parser)

    default_opts = [
        '-config', 'data/data.yaml', '-src_vocab', 'data/vocab-train.src',
        '-tgt_vocab', 'data/vocab-train.tgt'
    ]

    opt = parser.parse_known_args(default_opts)[0]
    # Inject a dummy training option that may be needed when building fields
    opt.copy_attn = False
    ArgumentParser.validate_prepare_opts(opt)
    return opt
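
A quick usage sketch, assuming the relative paths baked into `default_opts` (data/data.yaml and the two vocab files) actually exist on disk:

opt = get_default_opts()
# The validated namespace carries the prepare options plus the
# injected copy_attn flag.
print(opt.src_vocab, opt.tgt_vocab, opt.copy_attn)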
Example #3
import unittest

import torch

import onmt
import onmt.inputters
import onmt.opts
from onmt.model_builder import build_embeddings, \
    build_encoder, build_decoder
from onmt.utils.parse import ArgumentParser

parser = ArgumentParser(description='train.py')
onmt.opts.model_opts(parser)
onmt.opts._add_train_general_opts(parser)

# -data option is required, but not used in this test, so dummy.
opt = parser.parse_known_args(['-data', 'dummy'])[0]


class TestModel(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestModel, self).__init__(*args, **kwargs)
        self.opt = opt

    def get_field(self):
        src = onmt.inputters.get_fields("text", 0, 0)["src"]
        src.base_field.build_vocab([])
        return src

    def get_batch(self, source_l=3, bsize=1):
        # len x batch x nfeat
        test_src = torch.ones(source_l, bsize, 1).long()
        test_tgt = torch.ones(source_l, bsize, 1).long()
        test_length = torch.ones(bsize).fill_(source_l).long()
        return test_src, test_tgt, test_length
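
To show how these helpers are meant to be consumed, here is a hedged sketch of a test method for the class above; the method name and assertions are illustrative, not part of the original suite:

    def test_batch_shapes(self):
        # Hypothetical test: tensors come back as len x batch x nfeat.
        src, tgt, lengths = self.get_batch(source_l=5, bsize=2)
        self.assertEqual(src.shape, (5, 2, 1))
        self.assertEqual(lengths.tolist(), [5, 5])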
Example #4
import unittest

import torch

import onmt
import onmt.inputters
import onmt.opts
from onmt.model_builder import build_embeddings, \
    build_encoder, build_decoder
from onmt.encoders.image_encoder import ImageEncoder
from onmt.encoders.audio_encoder import AudioEncoder
from onmt.utils.parse import ArgumentParser

parser = ArgumentParser(description='train.py')
onmt.opts.model_opts(parser)
onmt.opts.train_opts(parser)

# -data option is required, but not used in this test, so dummy.
opt = parser.parse_known_args(['-data', 'dummy'])[0]


class TestModel(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        super(TestModel, self).__init__(*args, **kwargs)
        self.opt = opt

    def get_field(self):
        src = onmt.inputters.get_fields("text", 0, 0)["src"]
        src.base_field.build_vocab([])
        return src

    def get_batch(self, source_l=3, bsize=1):
        # len x batch x nfeat
        test_src = torch.ones(source_l, bsize, 1).long()
        test_tgt = torch.ones(source_l, bsize, 1).long()
        test_length = torch.ones(bsize).fill_(source_l).long()
        return test_src, test_tgt, test_length
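
Since this variant also pulls in the builder helpers, a hedged sketch of the kind of method such a test adds; the method name and the final assertion are illustrative:

    def test_build_encoder(self):
        # Hypothetical test: build embeddings and an encoder from opt.
        src_field = self.get_field()
        embeddings = build_embeddings(self.opt, src_field)
        encoder = build_encoder(self.opt, embeddings)
        self.assertIsInstance(encoder, torch.nn.Module)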
Example #5
import sys
from argparse import ArgumentParser

from onmt.bin.preprocess import main as preprocess  # OpenNMT-py 1.x entry point
from onmt.bin.train import main as train
from onmt.bin.translate import main as translate

if __name__ == '__main__':
    parser = ArgumentParser()

    # Simply add an argument for preprocess, train, translate
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument("--preprocess",
                            dest='preprocess',
                            action='store_true',
                            help="Activate to preprocess with OpenNMT")
    mode_group.add_argument("--train",
                            dest='train',
                            action='store_true',
                            help="Activate to train with OpenNMT")
    mode_group.add_argument("--translate",
                            dest='translate',
                            action='store_true',
                            help="Activate to translate with OpenNMT")

    mode, remaining_args = parser.parse_known_args()

    # The OpenNMT entry points parse sys.argv themselves, so strip the
    # wrapper's mode flag before handing control over.
    sys.argv = [sys.argv[0]] + remaining_args

    if mode.preprocess:
        preprocess()
    elif mode.train:
        train()
    elif mode.translate:
        translate()
        # TODO compute scores directly after the translation is done
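
Invocation then looks like the following (the script name and option values are placeholders); everything after the mode flag is forwarded untouched to the selected OpenNMT entry point:

# python wrapper.py --train -config my_config.yaml -save_model demo
# python wrapper.py --translate -model demo_step_1000.pt -src test.src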
Example #6
from collections import Counter, defaultdict

import onmt
import onmt.inputters.inputter
from onmt.bin.build_vocab import build_vocab_main
from onmt.inputters.corpus import ParallelCorpus
from onmt.inputters.dynamic_iterator import DynamicDatasetIter
from onmt.opts import dynamic_prepare_opts
from onmt.utils.parse import ArgumentParser


def preprocess(
    src_train, tgt_train,
    src_val, tgt_val,
    src_vocab_path, tgt_vocab_path,
    train_batch_size, valid_batch_size,
    device_code, train_num,
    vocab_max_size=800000,
):
    # onmt.utils.logging.init_logger()

    # Build the vocabs
    parser = ArgumentParser(description='build_vocab.py')
    dynamic_prepare_opts(parser, build_vocab_only=True)
    base_args = [
        '-config', '/data7/private/qianhoude/data/config.yaml',
        '-n_sample', str(train_num)
    ]
    opts, unknown = parser.parse_known_args(base_args)
    build_vocab_main(opts)

    # Initialize the frequency counter
    counters = defaultdict(Counter)
    # Load source vocab
    _src_vocab, _src_vocab_size = onmt.inputters.inputter._load_vocab(
        src_vocab_path,
        'src',
        counters
    )
    # Load target vocab
    _tgt_vocab, _tgt_vocab_size = onmt.inputters.inputter._load_vocab(
        tgt_vocab_path,
        'tgt',
        counters
    )

    # Initialize fields
    src_nfeats, tgt_nfeats = 0, 0
    fields = onmt.inputters.inputter.get_fields(
        'text', src_nfeats, tgt_nfeats
    )

    # Build fields vocab
    share_vocab = False
    vocab_size_multiple = 1
    src_vocab_size = vocab_max_size
    tgt_vocab_size = vocab_max_size
    src_words_min_frequency = 1
    tgt_words_min_frequency = 1
    vocab_fields = onmt.inputters.inputter._build_fields_vocab(
        fields, counters, 'text', share_vocab,
        vocab_size_multiple,
        src_vocab_size, src_words_min_frequency,
        tgt_vocab_size, tgt_words_min_frequency,
    )

    src_text_field = vocab_fields['src'].base_field
    src_vocab = src_text_field.vocab
    src_padding = src_vocab.stoi[src_text_field.pad_token]

    tgt_text_field = vocab_fields['tgt'].base_field
    tgt_vocab = tgt_text_field.vocab
    tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]

    # Build the ParallelCorpus
    corpus = ParallelCorpus('corpus', src_train, tgt_train)
    valid = ParallelCorpus('valid', src_val, tgt_val)

    # Build the training iterator
    train_iter = DynamicDatasetIter(
        corpora={'corpus': corpus},
        corpora_info={'corpus': {'weight': 1}},
        transforms={},
        fields=vocab_fields,
        is_train=True,
        batch_type='tokens',
        batch_size=train_batch_size,
        batch_size_multiple=1,
        data_type='text'
    )

    # Move batches onto the requested device (-1 for CPU, N for GPU N)
    train_iter = iter(
        onmt.inputters.inputter.IterOnDevice(train_iter, device_code)
    )

    # Build the validation iterator
    valid_iter = DynamicDatasetIter(
        corpora={'valid': valid},
        corpora_info={'valid': {'weight': 1}},
        transforms={},
        fields=vocab_fields,
        is_train=False,
        batch_type='sents',
        batch_size=valid_batch_size,
        batch_size_multiple=1,
        data_type='text'
    )
    valid_iter = onmt.inputters.inputter.IterOnDevice(valid_iter, device_code)

    return train_iter, valid_iter, src_vocab, tgt_vocab
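
An end-to-end usage sketch; every path and number below is a placeholder, and device_code follows the convention noted above (-1 for CPU, N for GPU N):

train_iter, valid_iter, src_vocab, tgt_vocab = preprocess(
    src_train='data/train.src', tgt_train='data/train.tgt',
    src_val='data/valid.src', tgt_val='data/valid.tgt',
    src_vocab_path='data/vocab.src', tgt_vocab_path='data/vocab.tgt',
    train_batch_size=4096,    # tokens per batch (batch_type='tokens')
    valid_batch_size=8,       # sentences per batch (batch_type='sents')
    device_code=-1,           # CPU
    train_num=10000,
)
batch = next(train_iter)      # batches arrive already on the chosen device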