Example #1
 def test_transform_register(self):
     builtin_transform = [
         "filtertoolong",
         "prefix",
         "sentencepiece",
         "bpe",
         "onmt_tokenize",
         "bart",
         "switchout",
         "tokendrop",
         "tokenmask",
     ]
     get_transforms_cls(builtin_transform)
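The registered names map to transform classes. A minimal follow-on sketch (assuming `get_transforms_cls` is imported from `onmt.transforms`, as these tests do) of what the call returns:

from onmt.transforms import get_transforms_cls

# get_transforms_cls returns a dict mapping each requested transform name
# to its registered transform class.
transforms_cls = get_transforms_cls(["prefix", "filtertoolong"])
for name, cls in transforms_cls.items():
    print(name, "->", cls)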
Example #2
    def test_prefix(self):
        prefix_cls = get_transforms_cls(["prefix"])["prefix"]
        corpora = yaml.safe_load("""
            trainset:
                path_src: data/src-train.txt
                path_tgt: data/tgt-train.txt
                transforms: [prefix]
                weight: 1
                src_prefix: "⦅_pf_src⦆"
                tgt_prefix: "⦅_pf_tgt⦆"
        """)
        opt = Namespace(data=corpora, seed=-1)
        prefix_transform = prefix_cls(opt)
        prefix_transform.warm_up()
        self.assertIn("trainset", prefix_transform.prefix_dict)

        ex_in = {
            "src": ["Hello", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }
        with self.assertRaises(ValueError):
            prefix_transform.apply(ex_in)
        with self.assertRaises(ValueError):
            prefix_transform.apply(ex_in, corpus_name="validset")
        ex_out = prefix_transform.apply(ex_in, corpus_name="trainset")
        self.assertEqual(ex_out["src"][0], "⦅_pf_src⦆")
        self.assertEqual(ex_out["tgt"][0], "⦅_pf_tgt⦆")
Example #3
 def test_sentencepiece(self):
     sp_cls = get_transforms_cls(["sentencepiece"])["sentencepiece"]
     base_opt = copy.copy(self.base_opts)
     base_opt["src_subword_model"] = "data/sample.sp.model"
     base_opt["tgt_subword_model"] = "data/sample.sp.model"
     opt = Namespace(**base_opt)
     sp_cls._validate_options(opt)
     sp_transform = sp_cls(opt)
     sp_transform.warm_up()
     ex = {
         "src": ["Hello", "world", "."],
         "tgt": ["Bonjour", "le", "monde", "."],
     }
     sp_transform.apply(ex, is_train=True)
     ex_gold = {
         "src": ["▁H", "el", "lo", "▁world", "▁."],
         "tgt": ["▁B", "on", "j", "o", "ur", "▁le", "▁m", "on", "de", "▁."],
     }
     self.assertEqual(ex, ex_gold)
     # test SP regularization:
     sp_transform.src_subword_nbest = 4
     tokens = ["Another", "world", "."]
     gold_sp = ["▁An", "other", "▁world", "▁."]
     # 1. enable regularization for a training example
     after_sp = sp_transform._tokenize(tokens, is_train=True)
     self.assertEqual(after_sp, ["▁An", "o", "ther", "▁world", "▁."])
     # 2. disable regularization for a non-training example
     after_sp = sp_transform._tokenize(tokens, is_train=False)
     self.assertEqual(after_sp, gold_sp)
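The regularization toggled above presumably delegates to SentencePiece's own subword sampling. A standalone sketch of that underlying API, using the same `data/sample.sp.model` as the test (the `nbest_size`/`alpha` values are illustrative, not taken from the transform's options):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/sample.sp.model")
text = " ".join(["Another", "world", "."])
# Deterministic segmentation, as used for non-training examples.
print(sp.encode(text, out_type=str))
# Sampled segmentation (subword regularization), as used for training examples.
print(sp.encode(text, out_type=str, enable_sampling=True, nbest_size=4, alpha=0.1))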
Example #4
def prepare_fields_transforms(opt):
    """Prepare or dump fields & transforms before training."""
    transforms_cls = get_transforms_cls(opt._all_transform)
    specials = get_specials(opt, transforms_cls)

    fields = build_dynamic_fields(opt,
                                  src_specials=specials['src'],
                                  tgt_specials=specials['tgt'])

    # maybe prepare pretrained embeddings, if any
    prepare_pretrained_embeddings(opt, fields)

    if opt.dump_fields:
        save_fields(fields, opt.save_data, overwrite=opt.overwrite)
    if opt.dump_transforms or opt.n_sample != 0:
        transforms = make_transforms(opt, transforms_cls, fields)
    if opt.dump_transforms:
        save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
    if opt.n_sample != 0:
        logger.warning("`-n_sample` != 0: Training will not be started. "
                       f"Stop after saving {opt.n_sample} samples/corpus.")
        save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
        logger.info("Sample saved, please check it before restart training.")
        sys.exit()
    return fields, transforms_cls
Example #5
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    ArgumentParser._get_all_transform_translate(opt)
    ArgumentParser._validate_transforms_opts(opt)
    ArgumentParser.validate_translate_opts_dynamic(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)

    data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats)

    # Build transforms
    transforms_cls = get_transforms_cls(opt._all_transform)
    transforms = make_transforms(opt, transforms_cls, translator.fields)
    data_transform = [
        transforms[name] for name in opt.transforms if name in transforms
    ]
    transform = TransformPipe.build_from(data_transform)

    for i, (src_shard, tgt_shard, feats_shard) in enumerate(data_reader):
        logger.info("Translating shard %d." % i)
        translator.translate_dynamic(src=src_shard,
                                     transform=transform,
                                     src_feats=feats_shard,
                                     tgt=tgt_shard,
                                     batch_size=opt.batch_size,
                                     batch_type=opt.batch_type,
                                     attn_debug=opt.attn_debug,
                                     align_debug=opt.align_debug)
Example #6
 def test_vocab_required_transform(self):
     transforms_cls = get_transforms_cls(["bart", "switchout"])
     opt = Namespace(seed=-1, switchout_temperature=1.0)
     # transforms that require a vocab are not created when no vocab is provided
     transforms = make_transforms(opt, transforms_cls, fields=None)
     self.assertEqual(len(transforms), 0)
     with self.assertRaises(ValueError):
         transforms_cls["switchout"](opt).warm_up(vocabs=None)
     with self.assertRaises(ValueError):
         transforms_cls["bart"](opt).warm_up(vocabs=None)
Example #7
 def test_filter_too_long(self):
     filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"]
     opt = Namespace(src_seq_length=100, tgt_seq_length=100)
     filter_transform = filter_cls(opt)
     # filter_transform.warm_up()
     ex_in = {
         "src": ["Hello", "world", "."],
         "tgt": ["Bonjour", "le", "monde", "."],
     }
     ex_out = filter_transform.apply(ex_in)
     self.assertIs(ex_out, ex_in)
     filter_transform.tgt_seq_length = 2
     ex_out = filter_transform.apply(ex_in)
     self.assertIsNone(ex_out)
Example #8
 def test_transform_pipe(self):
     # 1. Init first transform in the pipe
     prefix_cls = get_transforms_cls(["prefix"])["prefix"]
     corpora = yaml.safe_load("""
         trainset:
             path_src: data/src-train.txt
             path_tgt: data/tgt-train.txt
             transforms: [prefix, filtertoolong]
             weight: 1
             src_prefix: "⦅_pf_src⦆"
             tgt_prefix: "⦅_pf_tgt⦆"
     """)
     opt = Namespace(data=corpora, seed=-1)
     prefix_transform = prefix_cls(opt)
     prefix_transform.warm_up()
     # 2. Init second transform in the pipe
     filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"]
     opt = Namespace(src_seq_length=4, tgt_seq_length=4)
     filter_transform = filter_cls(opt)
     # 3. Sequentially combine them into a transform pipe
     transform_pipe = TransformPipe.build_from(
         [prefix_transform, filter_transform])
     ex = {
         "src": ["Hello", ",", "world", "."],
         "tgt": ["Bonjour", "le", "monde", "."],
     }
     # 4. apply the transform pipe to the example
     ex_after = transform_pipe.apply(copy.deepcopy(ex),
                                     corpus_name="trainset")
     # 5. after the pipe, the example exceeds the length limit, thus it is filtered
     self.assertIsNone(ex_after)
     # 6. Transform statistics registered (here for filtertoolong)
     self.assertTrue(len(transform_pipe.statistics.observables) > 0)
     msg = transform_pipe.statistics.report()
     self.assertIsNotNone(msg)
     # 7. after reporting, the statistics are emptied for a fresh start
     self.assertTrue(len(transform_pipe.statistics.observables) == 0)
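A short follow-on sketch, reusing `transform_pipe` and the `copy` import from the test above (so not self-contained): a shorter example stays within the src_seq_length/tgt_seq_length=4 limit even after prefixing, so it passes through the pipe instead of being filtered.

short_ex = {"src": ["Hello", "."], "tgt": ["Bonjour", "."]}
ex_after = transform_pipe.apply(copy.deepcopy(short_ex), corpus_name="trainset")
# 3 src tokens and 3 tgt tokens after prefixing, both <= 4, so nothing is filtered.
assert ex_after is not None
assert ex_after["src"][0] == "⦅_pf_src⦆" and ex_after["tgt"][0] == "⦅_pf_tgt⦆"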
Example #9
def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need vocab will be disabled in this.
    Built vocab is saved in plain text format as following and can be pass as
    `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert opts.n_sample == -1 or opts.n_sample > 1, \
        f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)
    fields = None

    transforms = make_transforms(opts, transforms_cls, fields)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    for feat_name, feat_counter in src_feats_counter.items():
        logger.info(f"Counters {feat_name}:{len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    if opts.share_vocab:
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")
        save_counter(src_counter, opts.src_vocab)
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)

    for k, v in src_feats_counter.items():
        save_counter(v, opts.src_feats_vocab[k])
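The files written by `save_counter` are the plain-text vocabularies described in the docstring. A hedged sketch of how they might then be referenced in a training configuration (paths are illustrative; the keys mirror the `-src_vocab`/`-tgt_vocab` options mentioned above):

# illustrative paths; point these at the files produced by build_vocab_main
src_vocab: run/example.vocab.src
tgt_vocab: run/example.vocab.tgt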
Example #10
 def test_tokenmask(self):
     tokenmask_cls = get_transforms_cls(["tokenmask"])["tokenmask"]
     opt = Namespace(seed=3434, tokenmask_temperature=0.1)
     tokenmask_transform = tokenmask_cls(opt)
     tokenmask_transform.warm_up()
     ex = {
         "src": ["Hello", ",", "world", "."],
         "tgt": ["Bonjour", "le", "monde", "."],
     }
     # token masking is not applied to a non-training example
     ex_after = tokenmask_transform.apply(copy.deepcopy(ex), is_train=False)
     self.assertEqual(ex_after, ex)
     # token masking is applied to a training example
     ex_after = tokenmask_transform.apply(copy.deepcopy(ex), is_train=True)
     self.assertNotEqual(ex_after, ex)
Example #11
 def test_transform_specials(self):
     transforms_cls = get_transforms_cls(["prefix"])
     corpora = yaml.safe_load("""
         trainset:
             path_src: data/src-train.txt
             path_tgt: data/tgt-train.txt
             transforms: ["prefix"]
             weight: 1
             src_prefix: "⦅_pf_src⦆"
             tgt_prefix: "⦅_pf_tgt⦆"
     """)
     opt = Namespace(data=corpora)
     specials = get_specials(opt, transforms_cls)
     specials_expected = {"src": {"⦅_pf_src⦆"}, "tgt": {"⦅_pf_tgt⦆"}}
     self.assertEqual(specials, specials_expected)
Example #12
 def test_bpe(self):
     bpe_cls = get_transforms_cls(["bpe"])["bpe"]
     opt = Namespace(**self.base_opts)
     bpe_cls._validate_options(opt)
     bpe_transform = bpe_cls(opt)
     bpe_transform.warm_up()
     ex = {
         "src": ["Hello", "world", "."],
         "tgt": ["Bonjour", "le", "monde", "."],
     }
     bpe_transform.apply(ex, is_train=True)
     ex_gold = {
         "src": ["H@@", "ell@@", "o", "world", "."],
         "tgt": ["B@@", "on@@", "j@@", "our", "le", "mon@@", "de", "."],
     }
     self.assertEqual(ex, ex_gold)
     # test BPE-dropout:
     bpe_transform.dropout["src"] = 1.0
     tokens = ["Another", "world", "."]
     gold_bpe = ["A@@", "no@@", "ther", "world", "."]
     gold_dropout = [
         "A@@",
         "n@@",
         "o@@",
         "t@@",
         "h@@",
         "e@@",
         "r",
         "w@@",
         "o@@",
         "r@@",
         "l@@",
         "d",
         ".",
     ]
     # 1. bpe dropout is disabled for a non-training example
     after_bpe = bpe_transform._tokenize(tokens, is_train=False)
     self.assertEqual(after_bpe, gold_bpe)
     # 2. bpe dropout is enabled for a training example
     after_bpe = bpe_transform._tokenize(tokens, is_train=True)
     self.assertEqual(after_bpe, gold_dropout)
     # 3. (NOTE) disabling dropout has no effect for tokens already seen:
     # the BPE cache returns the previously cached subwords for a known
     # token when no dropout is applied
     after_bpe2 = bpe_transform._tokenize(tokens, is_train=False)
     self.assertEqual(after_bpe2, gold_dropout)
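The dropout behaviour (and the cache caveat in note 3) comes from the underlying BPE segmenter; the "@@" joiners in the expected output suggest the subword-nmt implementation. A standalone sketch of that library's dropout API under that assumption (the codes-file path is hypothetical):

import codecs
from subword_nmt.apply_bpe import BPE

# hypothetical path to a BPE merge-codes file; the tests take theirs from self.base_opts
with codecs.open("data/sample.bpe.codes", encoding="utf-8") as codes:
    bpe = BPE(codes)

tokens = ["Another", "world", "."]
print(bpe.segment_tokens(tokens))               # deterministic merges
print(bpe.segment_tokens(tokens, dropout=1.0))  # all merges dropped: character-level pieces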
Example #13
def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need vocab will be disabled in this.
    Built vocab is saved in plain text format as following and can be pass as
    `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert opts.n_sample == -1 or opts.n_sample > 1, \
        f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)
    fields = None

    transforms = make_transforms(opts, transforms_cls, fields)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter = save_transformed_sample(opts,
                                                       transforms,
                                                       n_sample=opts.n_sample,
                                                       build_vocab=True)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    if opts.share_vocab:
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")

    def save_counter(counter, save_path):
        with open(save_path, "w") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    save_counter(src_counter, opts.save_data + '.vocab.src')
    save_counter(tgt_counter, opts.save_data + '.vocab.tgt')
Example #14
    def test_inferfeats(self):
        inferfeats_cls = get_transforms_cls(["inferfeats"])["inferfeats"]
        opt = Namespace(reversible_tokenization="joiner",
                        prior_tokenization=False)
        inferfeats_transform = inferfeats_cls(opt)

        ex_in = {
            "src": [
                'however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she',
                'is', 'hard', '■-■', 'working', '■.'
            ],
            "tgt": [
                'however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she',
                'is', 'hard', '■-■', 'working', '■.'
            ]
        }
        ex_out = inferfeats_transform.apply(ex_in)
        self.assertIs(ex_out, ex_in)

        ex_in["src_feats"] = {
            "feat_0": ["A", "A", "A", "A", "B", "A", "A", "C"]
        }
        ex_out = inferfeats_transform.apply(ex_in)
        self.assertEqual(ex_out["src_feats"]["feat_0"], [
            "A", "<null>", "A", "A", "A", "B", "<null>", "A", "A", "C",
            "<null>", "C", "<null>"
        ])

        ex_in["src"] = [
            '⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the',
            'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard',
            '■-■', 'working', '⦅mrk_end_case_region_U⦆', '■.'
        ]
        ex_in["src_feats"] = {
            "feat_0": ["A", "A", "A", "A", "B", "A", "A", "C"]
        }
        ex_out = inferfeats_transform.apply(ex_in)
        self.assertEqual(ex_out["src_feats"]["feat_0"], [
            "A", "A", "<null>", "A", "A", "A", "B", "<null>", "A", "A", "A",
            "C", "<null>", "C", "C", "<null>"
        ])
Example #15
def _init_train(opt):
    """Common initilization stuff for all training process."""
    ArgumentParser.validate_prepare_opts(opt)

    if opt.train_from:
        # Load checkpoint if we resume from a previous training.
        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
        fields = load_fields(opt.save_data, checkpoint)
        transforms_cls = get_transforms_cls(opt._all_transform)
        if (hasattr(checkpoint["opt"], '_all_transform') and
                len(opt._all_transform.symmetric_difference(
                    checkpoint["opt"]._all_transform)) != 0):
            _msg = "configured transforms is different from checkpoint:"
            new_transf = opt._all_transform.difference(
                checkpoint["opt"]._all_transform)
            old_transf = checkpoint["opt"]._all_transform.difference(
                opt._all_transform)
            if len(new_transf) != 0:
                _msg += f" +{new_transf}"
            if len(old_transf) != 0:
                _msg += f" -{old_transf}."
            logger.warning(_msg)
        if opt.update_vocab:
            logger.info("Updating checkpoint vocabulary with new vocabulary")
            fields, transforms_cls = prepare_fields_transforms(opt)
    else:
        checkpoint = None
        fields, transforms_cls = prepare_fields_transforms(opt)

    # Report src and tgt vocab sizes
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))
    return checkpoint, fields, transforms_cls
Example #16
 def test_pyonmttok_bpe(self):
     onmttok_cls = get_transforms_cls(["onmt_tokenize"])["onmt_tokenize"]
     base_opt = copy.copy(self.base_opts)
     base_opt["src_subword_type"] = "bpe"
     base_opt["tgt_subword_type"] = "bpe"
     onmt_args = "{'mode': 'space', 'joiner_annotate': True}"
     base_opt["src_onmttok_kwargs"] = onmt_args
     base_opt["tgt_onmttok_kwargs"] = onmt_args
     opt = Namespace(**base_opt)
     onmttok_cls._validate_options(opt)
     onmttok_transform = onmttok_cls(opt)
     onmttok_transform.warm_up()
     ex = {
         "src": ["Hello", "world", "."],
         "tgt": ["Bonjour", "le", "monde", "."],
     }
     onmttok_transform.apply(ex, is_train=True)
     ex_gold = {
         "src": ["H■", "ell■", "o", "world", "."],
         "tgt": ["B■", "on■", "j■", "our", "le", "mon■", "de", "."],
     }
     self.assertEqual(ex, ex_gold)
Example #17
 def test_switchout(self):
     switchout_cls = get_transforms_cls(["switchout"])["switchout"]
     opt = Namespace(seed=3434, switchout_temperature=0.1)
     switchout_transform = switchout_cls(opt)
     with self.assertRaises(ValueError):
         # require vocabs to warm_up
         switchout_transform.warm_up(vocabs=None)
     vocabs = {
         "src": Namespace(itos=["A", "Fake", "vocab"]),
         "tgt": Namespace(itos=["A", "Fake", "vocab"]),
     }
     switchout_transform.warm_up(vocabs=vocabs)
     ex = {
         "src": ["Hello", ",", "world", "."],
         "tgt": ["Bonjour", "le", "monde", "."],
     }
     # switchout is not applied to a non-training example
     ex_after = switchout_transform.apply(copy.deepcopy(ex), is_train=False)
     self.assertEqual(ex_after, ex)
     # switchout is applied to a training example
     ex_after = switchout_transform.apply(copy.deepcopy(ex), is_train=True)
     self.assertNotEqual(ex_after, ex)
Example #18
 def test_pyonmttok_sp(self):
     onmttok_cls = get_transforms_cls(["onmt_tokenize"])["onmt_tokenize"]
     base_opt = copy.copy(self.base_opts)
     base_opt["src_subword_type"] = "sentencepiece"
     base_opt["tgt_subword_type"] = "sentencepiece"
     base_opt["src_subword_model"] = "data/sample.sp.model"
     base_opt["tgt_subword_model"] = "data/sample.sp.model"
     onmt_args = "{'mode': 'none', 'spacer_annotate': True}"
     base_opt["src_onmttok_kwargs"] = onmt_args
     base_opt["tgt_onmttok_kwargs"] = onmt_args
     opt = Namespace(**base_opt)
     onmttok_cls._validate_options(opt)
     onmttok_transform = onmttok_cls(opt)
     onmttok_transform.warm_up()
     ex = {
         "src": ["Hello", "world", "."],
         "tgt": ["Bonjour", "le", "monde", "."],
     }
     onmttok_transform.apply(ex, is_train=True)
     ex_gold = {
         "src": ["▁H", "el", "lo", "▁world", "▁."],
         "tgt": ["▁B", "on", "j", "o", "ur", "▁le", "▁m", "on", "de", "▁."],
     }
     self.assertEqual(ex, ex_gold)