def prepare_fields_transforms(opt):
    """Prepare or dump fields & transforms before training.

    Builds the dynamic fields (with any transform-specific special tokens),
    optionally dumps fields/transforms to disk, and — when `-n_sample` is
    non-zero — saves transformed samples and exits instead of training.
    """
    cls_by_name = get_transforms_cls(opt._all_transform)
    special_toks = get_specials(opt, cls_by_name)
    fields = build_dynamic_fields(
        opt,
        src_specials=special_toks['src'],
        tgt_specials=special_toks['tgt'],
    )
    # maybe prepare pretrained embeddings, if any
    prepare_pretrained_embeddings(opt, fields)

    if opt.dump_fields:
        save_fields(fields, opt.save_data, overwrite=opt.overwrite)

    sampling = opt.n_sample != 0
    if opt.dump_transforms or sampling:
        # Instantiate transforms only when something below needs them.
        transforms = make_transforms(opt, cls_by_name, fields)
        if opt.dump_transforms:
            save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
        if sampling:
            logger.warning("`-n_sample` != 0: Training will not be started. "
                           f"Stop after saving {opt.n_sample} samples/corpus.")
            save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
            logger.info("Sample saved, please check it before restart training.")
            # Sampling mode deliberately terminates the process here.
            sys.exit()
    return fields, cls_by_name
def translate(opt):
    """Validate options, build a translator, and translate every shard.

    Applies the command-line transform pipeline dynamically to each
    (src, tgt, feats) shard produced by the inference data reader.
    """
    ArgumentParser.validate_translate_opts(opt)
    ArgumentParser._get_all_transform_translate(opt)
    ArgumentParser._validate_transforms_opts(opt)
    ArgumentParser.validate_translate_opts_dynamic(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)
    reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats)

    # Build transforms: keep only the ones requested on the command line
    # that were actually constructed.
    transforms_cls = get_transforms_cls(opt._all_transform)
    available = make_transforms(opt, transforms_cls, translator.fields)
    chosen = [available[name] for name in opt.transforms if name in available]
    transform = TransformPipe.build_from(chosen)

    for shard_idx, (src_shard, tgt_shard, feats_shard) in enumerate(reader):
        logger.info("Translating shard %d." % shard_idx)
        translator.translate_dynamic(
            src=src_shard,
            transform=transform,
            src_feats=feats_shard,
            tgt=tgt_shard,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug,
        )
def test_vocab_required_transform(self):
    """Vocab-requiring transforms must refuse to warm up without a vocab."""
    transforms_cls = get_transforms_cls(["bart", "switchout"])
    opt = Namespace(seed=-1, switchout_temperature=1.0)
    # transforms that require vocab will not create if not provide vocab
    transforms = make_transforms(opt, transforms_cls, fields=None)
    self.assertEqual(len(transforms), 0)
    # FIX: check each transform in its own assertRaises context. With both
    # calls under a single context manager, the first raise exits the block
    # and the second transform ("bart") is never actually exercised.
    with self.assertRaises(ValueError):
        transforms_cls["switchout"](opt).warm_up(vocabs=None)
    with self.assertRaises(ValueError):
        transforms_cls["bart"](opt).warm_up(vocabs=None)
def build_dynamic_dataset_iter(fields, transforms_cls, opts, is_train=True,
                               stride=1, offset=0):
    """Build `DynamicDatasetIter` from fields & opts.

    Returns None when no corpus is configured — permitted only for the
    validation side (`is_train=False`).
    """
    pipeline = make_transforms(opts, transforms_cls, fields)
    corpora = get_corpora(opts, is_train)
    if corpora is None:
        # A missing validation corpus is tolerated; a missing training
        # corpus is a configuration error.
        assert not is_train, "only valid corpus is ignorable."
        return None
    return DynamicDatasetIter.from_opts(
        corpora,
        pipeline,
        fields,
        opts,
        is_train,
        stride=stride,
        offset=offset,
    )
def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need vocab will be disabled in this.
    Built vocab is saved in plain text format as following and can be pass as
    `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """
    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert opts.n_sample == -1 or opts.n_sample > 1, \
        f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)

    # fields=None disables the transforms that would need a vocab.
    transforms_cls = get_transforms_cls(opts._all_transform)
    transforms = make_transforms(opts, transforms_cls, None)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    for feat_name, feat_counter in src_feats_counter.items():
        logger.info(f"Counters {feat_name}:{len(feat_counter)}")

    def save_counter(counter, save_path):
        # Refuse to clobber an existing vocab file unless -overwrite is set.
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(f"{tok}\t{count}\n")

    if opts.share_vocab:
        # Merge target counts into source and alias both to one counter.
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")
        save_counter(src_counter, opts.src_vocab)
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)

    for k, v in src_feats_counter.items():
        save_counter(v, opts.src_feats_vocab[k])
def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need vocab will be disabled in this.
    Built vocab is saved in plain text format as following and can be pass as
    `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """
    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert opts.n_sample == -1 or opts.n_sample > 1, \
        f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)
    # fields=None disables the transforms that would need a vocab.
    fields = None
    transforms = make_transforms(opts, transforms_cls, fields)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter = save_transformed_sample(
        opts, transforms, n_sample=opts.n_sample, build_vocab=True)
    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    if opts.share_vocab:
        # Merge target counts into source and alias both to one counter,
        # so both output files carry the shared vocabulary.
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")

    def save_counter(counter, save_path):
        # FIX: open with an explicit utf8 encoding. The implicit default is
        # locale-dependent and can raise UnicodeEncodeError on non-ASCII
        # tokens (e.g. under cp1252 on Windows); this also matches the
        # sibling build_vocab_main implementation in this file.
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    save_counter(src_counter, opts.save_data + '.vocab.src')
    save_counter(tgt_counter, opts.save_data + '.vocab.tgt')